
Introduction to torch.compile

Author: William Wen

torch.compile is the latest method to speed up your PyTorch code! torch.compile makes PyTorch code run faster by JIT-compiling PyTorch code into optimized kernels, all while requiring minimal code changes.

In this tutorial, we cover basic torch.compile usage, and demonstrate the advantages of torch.compile over previous PyTorch compiler solutions, such as TorchScript and FX Tracing.


Required pip Dependencies

  • torch >= 2.0

  • torchvision

  • numpy

  • scipy

  • tabulate

NOTE: a modern NVIDIA GPU (H100, A100, or V100) is recommended for this tutorial in order to reproduce the speedup numbers shown below and documented elsewhere.

import torch
import warnings

gpu_ok = False
if torch.cuda.is_available():
    device_cap = torch.cuda.get_device_capability()
    if device_cap in ((7, 0), (8, 0), (9, 0)):
        gpu_ok = True

if not gpu_ok:
    warnings.warn(
        "GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower "
        "than expected."
    )
/var/lib/workspace/intermediate_source/torch_compile_tutorial.py:48: UserWarning:

GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower than expected.

Basic Usage

torch.compile is included in PyTorch 2.0 and later. Running TorchInductor on GPU requires Triton, which is included with the PyTorch 2.0 nightly binary. If Triton is still missing, try installing torchtriton via pip (pip install torchtriton --extra-index-url "https://download.pytorch.org/whl/nightly/cu117" for CUDA 11.7).
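Before running the examples, it can help to check what your environment provides. The snippet below is a minimal sketch; the helper name describe_compile_env is hypothetical, and it only reports whether a triton package is importable, not whether TorchInductor can actually use it.

```python
import importlib.util

import torch


def describe_compile_env():
    """Hypothetical helper: report the pieces torch.compile relies on."""
    # torch.compile first shipped in PyTorch 2.0.
    major = int(torch.__version__.split(".")[0])
    return {
        "torch_version": str(torch.__version__),
        "has_torch_compile": hasattr(torch, "compile"),
        "torch_2_or_newer": major >= 2,
        "cuda_available": torch.cuda.is_available(),
        # find_spec returns None when the package cannot be found.
        "triton_installed": importlib.util.find_spec("triton") is not None,
    }


print(describe_compile_env())
```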

Arbitrary Python functions can be optimized by passing the callable to torch.compile. We can then call the returned optimized function in place of the original function.

def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b
opt_foo1 = torch.compile(foo)
print(opt_foo1(torch.randn(10, 10), torch.randn(10, 10)))
tensor([[ 1.6850,  1.9924,  1.7090,  0.0034,  1.1414, -0.1822,  0.4861, -0.0536,
         -0.2252,  1.9398],
        [ 0.3693, -0.0695,  0.1748,  0.3436,  0.1939,  1.5721,  1.9882, -0.2235,
          0.3161,  1.2642],
        [ 0.2480,  1.8793,  1.7152,  1.6772,  1.8881,  1.4748,  1.3466,  1.7763,
          0.7469,  1.0407],
        [-0.1121,  1.6015, -0.0188,  0.2128,  0.5218,  1.9838,  0.8185,  0.5093,
         -0.3603,  0.1793],
        [-1.7890,  1.7532, -0.4040,  0.1222, -0.0029,  1.7975, -0.3877,  0.5123,
          0.1673,  0.1330],
        [ 1.0627,  0.9609,  0.1019,  1.8814,  0.1142, -0.2338, -0.9621,  0.7631,
          0.6506,  0.1853],
        [ 0.4584,  1.7648, -0.0444,  1.9610,  1.5884,  0.7353,  1.2190,  1.3662,
          1.0938, -0.1587],
        [-0.7502,  1.6640,  0.3495,  1.3496,  0.8187,  1.1719,  0.5820,  0.1498,
          0.0885,  0.1036],
        [ 0.3961,  0.6043, -0.0861, -0.3371,  0.8622,  1.4341,  1.2988,  0.5023,
          0.3074,  0.1277],
        [ 0.9748,  0.4117,  1.2616,  1.6314,  0.4693,  0.4092,  0.0401,  1.1196,
          1.2458,  1.3280]])

Alternatively, we can decorate the function.

@torch.compile
def opt_foo2(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b
print(opt_foo2(torch.randn(10, 10), torch.randn(10, 10)))
tensor([[ 0.5360,  0.1697, -0.0561,  0.1890, -0.1310,  1.2276,  1.1739,  0.1944,
         -0.1561,  1.6990],
        [ 1.0421,  1.9472,  0.2682,  0.2701,  1.3346,  0.7651,  1.0897,  1.1730,
          0.6161,  0.9223],
        [ 1.5756,  1.5294,  0.0112, -0.1522, -0.7674,  1.8515, -0.2443,  0.3696,
          0.2693,  0.8735],
        [-0.3701,  1.1190,  1.4164,  1.8648,  1.2080,  0.0732,  1.5274,  0.6868,
          1.2440,  1.0715],
        [-1.2454, -0.0159,  0.4315,  0.1317,  1.0530, -1.0603, -0.0532,  0.6661,
          1.7101, -0.2076],
        [-0.7091,  0.7824,  1.7161,  1.2750,  0.6368,  1.2488,  0.4897,  1.2429,
          1.3409,  1.3735],
        [ 0.8345,  0.0653,  0.3462,  1.2383, -0.4092,  1.6438, -0.0962,  0.4011,
          0.2463, -0.5802],
        [ 1.6349,  0.7297,  1.2547, -0.3113,  0.9310,  0.1162,  1.7618,  0.4882,
          0.7640,  0.2930],
        [ 1.1669, -0.7775,  1.2000,  0.6008, -0.2814,  0.5541,  0.5753,  1.4731,
          1.6835,  0.7370],
        [ 1.5087,  0.6195,  0.1153,  1.2966,  1.8815,  1.1678,  1.5686,  1.6018,
          0.2193,  1.3500]])
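torch.compile also accepts its keyword arguments when used as a decorator. The sketch below passes backend="eager", which runs the captured graph without TorchInductor (so it needs neither Triton nor a C++ compiler). That choice is an assumption made only for portability; omit it to get the default TorchInductor backend and its speedups.

```python
import torch


# Keyword arguments go inside the decorator call. backend="eager" skips
# TorchInductor code generation; remove it to use the default backend.
@torch.compile(backend="eager")
def opt_foo3(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b


x, y = torch.randn(10, 10), torch.randn(10, 10)
# The compiled function should match the eager computation.
expected = torch.sin(x) + torch.cos(y)
print(torch.allclose(opt_foo3(x, y), expected))
```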

We can also optimize torch.nn.Module instances.

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))

mod = MyModule()
opt_mod = torch.compile(mod)
print(opt_mod(torch.randn(10, 100)))
tensor([[0.0000, 0.0000, 0.2419, 0.0446, 0.9011, 0.2674, 0.3633, 0.4984, 0.0000,
         0.0988],
        [0.6906, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.8490, 0.0000, 0.0000,
         0.5475],
        [0.0852, 0.2762, 0.7441, 0.0000, 0.0000, 0.1820, 0.0000, 0.0000, 0.0000,
         0.0334],
        [0.3024, 0.0077, 1.2572, 0.0000, 0.0000, 0.6520, 0.0000, 0.0000, 0.0000,
         0.8976],
        [0.1998, 0.3333, 0.0000, 0.7803, 0.4202, 0.0915, 0.0000, 1.2543, 0.0000,
         0.4615],
        [0.2487, 0.4187, 0.0000, 0.0000, 0.5124, 0.0000, 0.2512, 0.0000, 0.5850,
         0.0000],
        [0.0000, 0.0048, 0.0000, 0.0000, 0.0000, 0.2287, 0.0000, 0.4841, 0.3915,
         0.0000],
        [0.2017, 0.0000, 0.0896, 1.4135, 0.0593, 0.3788, 0.0000, 0.0000, 0.0000,
         0.4972],
        [0.0000, 0.0000, 1.6580, 0.6414, 0.0000, 0.0000, 0.0000, 0.0000, 0.6491,
         0.7755],
        [0.0000, 0.0000, 0.6442, 0.0260, 0.7456, 0.1000, 0.0000, 0.0000, 0.5366,
         0.1193]], grad_fn=<CompiledFunctionBackward>)
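The compiled module wraps the original module rather than copying it, so both share the same parameters. As a minimal sketch (again using backend="eager" only so it runs anywhere), an in-place update to the original module's weights is visible through the compiled module; this is also what lets an optimizer train through a compiled model.

```python
import torch


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))


mod = MyModule()
opt_mod = torch.compile(mod, backend="eager")
x = torch.randn(10, 100)

# Same parameters, so the same outputs.
print(torch.allclose(opt_mod(x), mod(x)))

# Zero the original module's weights in place; the compiled module sees it.
with torch.no_grad():
    mod.lin.weight.zero_()
    mod.lin.bias.zero_()
print(opt_mod(x).abs().sum().item())  # relu(0) is 0 everywhere
```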

Demonstrating Speedups

Let’s now demonstrate that using torch.compile can speed up real models. We will compare standard eager mode and torch.compile by evaluating and training a torchvision model on random data.

Before we start, we need to define some utility functions.

# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000

# Generates random input and target data for the model, where `b` is
# batch size.
def generate_data(b):
    return (
        torch.randn(b, 3, 128, 128).to(torch.float32).cuda(),
        torch.randint(1000, (b,)).cuda(),
    )

N_ITERS = 10

from torchvision.models import densenet121
def init_model():
    return densenet121().to(torch.float32).cuda()
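The timed helper above relies on CUDA events, so it will not work on a machine without a GPU. If you want to experiment on CPU, one minimal fallback (the name timed_any_device is hypothetical) is to use time.perf_counter and synchronize only when CUDA is present. Expect CPU timings, and speedups, to differ from the numbers shown in this tutorial.

```python
import time

import torch


def timed_any_device(fn):
    """Hypothetical fallback for `timed`: returns (result, seconds)."""
    # Make sure earlier GPU work does not leak into this measurement.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    result = fn()
    # CUDA kernels run asynchronously; wait for them before stopping the clock.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return result, time.perf_counter() - start


out, seconds = timed_any_device(lambda: torch.randn(256, 256) @ torch.randn(256, 256))
print(out.shape, f"{seconds:.6f}s")
```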

First, let’s compare inference.

Note that in the call to torch.compile, we have the additional mode argument, which we will discuss below.

model = init_model()

# Reset since we are using a different mode.
import torch._dynamo
torch._dynamo.reset()

model_opt = torch.compile(model, mode="reduce-overhead")

inp = generate_data(16)[0]
with torch.no_grad():
    print("eager:", timed(lambda: model(inp))[1])
    print("compile:", timed(lambda: model_opt(inp))[1])
eager: 0.31575347900390627
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:124: UserWarning:

TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.

compile: 84.2865234375

Notice that the first torch.compile run takes much longer to complete than eager. This is because torch.compile compiles the model into optimized kernels as it executes. In our example, the structure of the model doesn’t change, so recompilation is not needed. If we run our optimized model several more times, we should see a significant improvement compared to eager.
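One way to see that compilation happens once and is then reused is to count how often the backend is invoked. The sketch below makes some assumptions: counting_backend is a hypothetical custom backend (custom backends are covered in more detail in the TorchDynamo section of this tutorial) that counts each graph it receives and returns the graph's forward method unchanged. Calls with the same input shapes reuse the cached code, while a new shape fails the guards and triggers another compilation.

```python
import torch
import torch._dynamo

compile_count = 0


def counting_backend(gm: torch.fx.GraphModule, example_inputs):
    """Hypothetical backend: count compilations, then run the graph as-is."""
    global compile_count
    compile_count += 1
    return gm.forward


def foo(x, y):
    return torch.sin(x) + torch.cos(y)


torch._dynamo.reset()  # start from an empty compilation cache
opt_foo = torch.compile(foo, backend=counting_backend)

# Five calls with the same shapes: compiled once, then reused.
for _ in range(5):
    opt_foo(torch.randn(10, 10), torch.randn(10, 10))
after_same_shape = compile_count
print("after same-shape calls:", after_same_shape)

# A new input shape fails the guards and triggers a recompilation.
opt_foo(torch.randn(20, 20), torch.randn(20, 20))
print("after a new shape:", compile_count)
```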

eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    with torch.no_grad():
        _, eager_time = timed(lambda: model(inp))
    eager_times.append(eager_time)
    print(f"eager eval time {i}: {eager_time}")

print("~" * 10)

compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    with torch.no_grad():
        _, compile_time = timed(lambda: model_opt(inp))
    compile_times.append(compile_time)
    print(f"compile eval time {i}: {compile_time}")
print("~" * 10)

import numpy as np
eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert speedup > 1
print(f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
print("~" * 10)
eager eval time 0: 0.017733631134033204
eager eval time 1: 0.016097280502319337
eager eval time 2: 0.016043008804321288
eager eval time 3: 0.01603993606567383
eager eval time 4: 0.01599897575378418
eager eval time 5: 0.01594265556335449
eager eval time 6: 0.015932415962219237
eager eval time 7: 0.015960063934326172
eager eval time 8: 0.015898624420166017
eager eval time 9: 0.015954943656921388
~~~~~~~~~~
compile eval time 0: 0.6094346313476563
compile eval time 1: 0.008541184425354004
compile eval time 2: 0.008338432312011718
compile eval time 3: 0.008339455604553223
compile eval time 4: 0.008328191757202149
compile eval time 5: 0.008346624374389648
compile eval time 6: 0.0083374080657959
compile eval time 7: 0.008326144218444824
compile eval time 8: 0.008348671913146973
compile eval time 9: 0.008347647666931152
~~~~~~~~~~
(eval) eager median: 0.015979519844055178, compile median: 0.008343039989471435, speedup: 1.9153114289540332x
~~~~~~~~~~

And indeed, we can see that running our model with torch.compile results in a significant speedup. The speedup mainly comes from reducing Python overhead and GPU memory reads and writes, so the observed speedup may vary depending on factors such as model architecture and batch size. For example, if a model’s architecture is simple and the amount of data is large, then the bottleneck would be GPU compute and the observed speedup may be less significant.
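To make the memory-traffic point concrete, consider foo from the Basic Usage section applied to fp32 tensors of N elements each, counting only global-memory traffic. This back-of-the-envelope sketch assumes each eager operator runs as a separate kernel that reads its inputs from GPU memory and writes its output back, and that the compiled version fuses all three operations into one kernel; it ignores caches and launch overhead.

```latex
\begin{aligned}
\text{eager:}\quad & \underbrace{(N + N)}_{\sin} + \underbrace{(N + N)}_{\cos} + \underbrace{(2N + N)}_{a + b}
  = 7N \text{ elements} = 28N \text{ bytes} \\
\text{fused:}\quad & \underbrace{2N}_{\text{read } x,\, y} + \underbrace{N}_{\text{write } a + b}
  = 3N \text{ elements} = 12N \text{ bytes}
\end{aligned}
```

Memory traffic drops by a factor of 7/3 ≈ 2.3, which is roughly the best case one could hope for from fusion alone on a memory-bound function like this; compute-bound workloads gain less.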

You may also see different speedup results depending on the chosen mode argument. The "reduce-overhead" mode uses CUDA graphs to further reduce the overhead of Python. For your own models, you may need to experiment with different modes to maximize speedup. You can read more about modes in the torch.compile documentation.

You might also notice that the second run of our model with torch.compile is significantly slower than the remaining runs, although it is still much faster than the first run. This is because the "reduce-overhead" mode runs a few warm-up iterations for CUDA graphs.

For general PyTorch benchmarking, you can try using torch.utils.benchmark instead of the timed function we defined above. We wrote our own timing function in this tutorial to show torch.compile’s compilation latency.
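A minimal sketch of such a comparison with torch.utils.benchmark follows; the helper name compare_eager_vs_compiled is hypothetical. It calls the compiled function once before timing so that compilation latency is excluded, then uses Timer.blocked_autorange, which chooses the number of runs itself and returns a Measurement whose median we report. Timer also takes care of CUDA synchronization when a GPU is present. The backend parameter defaults to TorchInductor; passing backend="eager" is only useful as a quick functional check on machines without a working Inductor toolchain.

```python
import torch
import torch.utils.benchmark as benchmark


def compare_eager_vs_compiled(fn, args, backend="inductor", min_run_time=0.2):
    """Hypothetical helper: return (eager_median_s, compiled_median_s)."""
    compiled_fn = torch.compile(fn, backend=backend)
    compiled_fn(*args)  # warm-up call: triggers compilation outside the timer

    medians = []
    for f in (fn, compiled_fn):
        timer = benchmark.Timer(stmt="f(*args)", globals={"f": f, "args": args})
        # blocked_autorange picks the iteration count and returns a Measurement.
        medians.append(timer.blocked_autorange(min_run_time=min_run_time).median)
    return medians[0], medians[1]


def foo(x, y):
    return torch.sin(x) + torch.cos(y)
```

For example, compare_eager_vs_compiled(foo, (torch.randn(1024, 1024), torch.randn(1024, 1024))) returns the eager and compiled medians in seconds.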

Now, let’s compare training.

model = init_model()
opt = torch.optim.Adam(model.parameters())

def train(mod, data):
    opt.zero_grad(True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
    opt.step()

eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
    _, eager_time = timed(lambda: train(model, inp))
    eager_times.append(eager_time)
    print(f"eager train time {i}: {eager_time}")
print("~" * 10)

model = init_model()
opt = torch.optim.Adam(model.parameters())
train_opt = torch.compile(train, mode="reduce-overhead")

compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
    _, compile_time = timed(lambda: train_opt(model, inp))
    compile_times.append(compile_time)
    print(f"compile train time {i}: {compile_time}")
print("~" * 10)

eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert speedup > 1
print(f"(train) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
print("~" * 10)
eager train time 0: 0.3381759948730469
eager train time 1: 0.04996505737304688
eager train time 2: 0.04853657531738281
eager train time 3: 0.04814438247680664
eager train time 4: 0.04834915161132813
eager train time 5: 0.04864921569824219
eager train time 6: 0.04833587265014649
eager train time 7: 0.048277503967285154
eager train time 8: 0.048293888092041014
eager train time 9: 0.048399360656738284
~~~~~~~~~~
compile train time 0: 209.49559375
compile train time 1: 4.77366162109375
compile train time 2: 0.03038617515563965
compile train time 3: 0.023971839904785155
compile train time 4: 0.02370047950744629
compile train time 5: 0.023814144134521483
compile train time 6: 0.02366464042663574
compile train time 7: 0.024535039901733398
compile train time 8: 0.02367692756652832
compile train time 9: 0.023723007202148438
~~~~~~~~~~
(train) eager median: 0.048374256134033206, compile median: 0.02389299201965332, speedup: 2.0246211146030886x
~~~~~~~~~~

Again, we can see that torch.compile takes longer in the first iteration, as it must compile the model, but in subsequent iterations, we see significant speedups compared to eager.
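Compiling the whole train step, as above, is one option. Another is to compile only the model and keep the loss computation and optimizer step in eager mode. The following is a small CPU-sized sketch of that variant, not the tutorial's benchmark: the single Linear layer, the fixed random batch, and the learning rate are arbitrary assumptions, and backend="eager" is used only so it runs without a GPU or compiler toolchain (omit it in practice).

```python
import torch

torch.manual_seed(0)

model = torch.nn.Linear(100, 10)
opt_model = torch.compile(model, backend="eager")  # drop backend= to use TorchInductor
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = torch.nn.CrossEntropyLoss()

# One fixed batch, so the loss should fall as the model fits it.
inputs = torch.randn(16, 100)
targets = torch.randint(10, (16,))

losses = []
for step in range(20):
    optimizer.zero_grad(set_to_none=True)
    loss = loss_fn(opt_model(inputs), targets)  # forward pass goes through the compiled model
    loss.backward()
    optimizer.step()  # eager optimizer updates the shared parameters in place
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```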

We remark that the speedup numbers presented in this tutorial are for demonstration purposes only. Official speedup values can be seen at the TorchInductor performance dashboard.

Comparison to TorchScript and FX Tracing

We have seen that torch.compile can speed up PyTorch code. Why else should we use torch.compile over existing PyTorch compiler solutions, such as TorchScript or FX Tracing? Primarily, the advantage of torch.compile lies in its ability to handle arbitrary Python code with minimal changes to existing code.

One case that torch.compile can handle that other compiler solutions struggle with is data-dependent control flow (the if x.sum() < 0: line below).

def f1(x, y):
    if x.sum() < 0:
        return -y
    return y

# Test that `fn1` and `fn2` return the same result, given
# the same arguments `args`. Typically, `fn1` will be an eager function
# while `fn2` will be a compiled function (torch.compile, TorchScript, or FX graph).
def test_fns(fn1, fn2, args):
    out1 = fn1(*args)
    out2 = fn2(*args)
    return torch.allclose(out1, out2)

inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)

Tracing f1 with TorchScript produces silently incorrect results, since only the control flow path taken by the example inputs is traced.

traced_f1 = torch.jit.trace(f1, (inp1, inp2))
print("traced 1, 1:", test_fns(f1, traced_f1, (inp1, inp2)))
print("traced 1, 2:", test_fns(f1, traced_f1, (-inp1, inp2)))
/var/lib/workspace/intermediate_source/torch_compile_tutorial.py:274: TracerWarning:

Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

traced 1, 1: True
traced 1, 2: False

FX tracing f1 results in an error due to the presence of data-dependent control flow.

import traceback as tb
try:
    torch.fx.symbolic_trace(f1)
except Exception:
    tb.print_exc()
Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 304, in <module>
    torch.fx.symbolic_trace(f1)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 1193, in symbolic_trace
    graph = tracer.trace(root, concrete_args)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 36, in inner
    return fn(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 793, in trace
    (self.create_arg(fn(*args)),),
  File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 274, in f1
    if x.sum() < 0:
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/proxy.py", line 443, in __bool__
    return self.tracer.to_bool(self)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/proxy.py", line 303, in to_bool
    raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow

If we provide a value for x as we try to FX trace f1, then we run into the same problem as TorchScript tracing, as the data-dependent control flow is removed in the traced function.

fx_f1 = torch.fx.symbolic_trace(f1, concrete_args={"x": inp1})
print("fx 1, 1:", test_fns(f1, fx_f1, (inp1, inp2)))
print("fx 1, 2:", test_fns(f1, fx_f1, (-inp1, inp2)))
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py:862: UserWarning:

Was not able to add assertion to guarantee correct input x to specialized function. It is up to the user to make sure that your inputs match the inputs you specialized the function with.

fx 1, 1: True
fx 1, 2: False

Now we can see that torch.compile correctly handles data-dependent control flow.

# Reset since we are using a different mode.
torch._dynamo.reset()

compile_f1 = torch.compile(f1)
print("compile 1, 1:", test_fns(f1, compile_f1, (inp1, inp2)))
print("compile 1, 2:", test_fns(f1, compile_f1, (-inp1, inp2)))
print("~" * 10)
compile 1, 1: True
compile 1, 2: True
~~~~~~~~~~

TorchScript scripting can handle data-dependent control flow, but this solution comes with its own set of problems. Namely, TorchScript scripting can require major code changes and will raise errors when unsupported Python is used.

In the example below, we omit TorchScript type annotations and receive a TorchScript error, because the input type for argument y, an int, does not match the default argument type, torch.Tensor.

def f2(x, y):
    return x + y

inp1 = torch.randn(5, 5)
inp2 = 3

script_f2 = torch.jit.script(f2)
try:
    script_f2(inp1, inp2)
except Exception:
    tb.print_exc()
Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 347, in <module>
    script_f2(inp1, inp2)
RuntimeError: f2() Expected a value of type 'Tensor (inferred)' for argument 'y' but instead found type 'int'.
Inferred 'y' to be of type 'Tensor' because it was not annotated with an explicit type.
Position: 1
Value: 3
Declaration: f2(Tensor x, Tensor y) -> Tensor
Cast error details: Unable to cast 3 to Tensor

However, torch.compile is easily able to handle f2.

compile_f2 = torch.compile(f2)
print("compile 2:", test_fns(f2, compile_f2, (inp1, inp2)))
print("~" * 10)
compile 2: True
~~~~~~~~~~

Another case that torch.compile handles well compared to previous compiler solutions is the use of non-PyTorch functions.

import scipy.fft
def f3(x):
    x = x * 2
    x = scipy.fft.dct(x.numpy())
    x = torch.from_numpy(x)
    x = x * 2
    return x

TorchScript tracing treats results from non-PyTorch function calls as constants, and so our results can be silently wrong.

inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)
traced_f3 = torch.jit.trace(f3, (inp1,))
print("traced 3:", test_fns(f3, traced_f3, (inp2,)))
/var/lib/workspace/intermediate_source/torch_compile_tutorial.py:365: TracerWarning:

Converting a tensor to a NumPy array might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

/var/lib/workspace/intermediate_source/torch_compile_tutorial.py:366: TracerWarning:

torch.from_numpy results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.

traced 3: False

TorchScript scripting and FX tracing disallow non-PyTorch function calls.

try:
    torch.jit.script(f3)
except Exception:
    tb.print_exc()

try:
    torch.fx.symbolic_trace(f3)
except Exception:
    tb.print_exc()
Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 383, in <module>
    torch.jit.script(f3)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/jit/_script.py", line 1395, in script
    fn = torch._C._jit_script_compile(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_jit_internal.py", line 1216, in _try_get_dispatched_fn
    return boolean_dispatched.get(fn)
  File "/opt/conda/envs/py_3.10/lib/python3.10/weakref.py", line 453, in get
    return self.data.get(ref(key),default)
TypeError: cannot create weak reference to 'uarray._Function' object
Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 388, in <module>
    torch.fx.symbolic_trace(f3)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 1193, in symbolic_trace
    graph = tracer.trace(root, concrete_args)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 36, in inner
    return fn(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 793, in trace
    (self.create_arg(fn(*args)),),
  File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 365, in f3
    x = scipy.fft.dct(x.numpy())
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/scipy/fft/_backend.py", line 25, in __ua_function__
    return fn(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/scipy/fft/_pocketfft/realtransforms.py", line 19, in _r2r
    tmp = _asfarray(x)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/scipy/fft/_pocketfft/helper.py", line 89, in _asfarray
    if x.dtype == np.float16:
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/proxy.py", line 548, in impl
    return tracer.create_proxy('call_function', target, args, kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/proxy.py", line 187, in create_proxy
    args_ = self.create_arg(args)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 401, in create_arg
    return super().create_arg(a)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/proxy.py", line 258, in create_arg
    return type(a)(self.create_arg(elem) for elem in a)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/proxy.py", line 258, in <genexpr>
    return type(a)(self.create_arg(elem) for elem in a)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 401, in create_arg
    return super().create_arg(a)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/proxy.py", line 294, in create_arg
    raise NotImplementedError(f"argument of type: {type(a)}")
NotImplementedError: argument of type: <class 'type'>

In comparison, torch.compile is easily able to handle the non-PyTorch function call.

compile_f3 = torch.compile(f3)
print("compile 3:", test_fns(f3, compile_f3, (inp2,)))
compile 3: True

TorchDynamo and FX Graphs

One important component of torch.compile is TorchDynamo. TorchDynamo is responsible for JIT compiling arbitrary Python code into FX graphs, which can then be further optimized. TorchDynamo extracts FX graphs by analyzing Python bytecode during runtime and detecting calls to PyTorch operations.
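TorchDynamo works on the same bytecode that the Python interpreter executes. As a rough illustration only (this is not Dynamo's actual algorithm), the standard-library dis module can list the bytecode of a function like foo from earlier; the lookups of torch.sin and torch.cos, which Dynamo recognizes as PyTorch operations, are visible in it. Exact opcode names vary between Python versions.

```python
import dis

import torch


def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b


# Print each instruction's opcode name and argument value.
for instr in dis.get_instructions(foo):
    print(f"{instr.opname:<20} {instr.argval!r}")

# Global and attribute names the function looks up: torch, sin, and cos.
looked_up = {
    i.argval
    for i in dis.get_instructions(foo)
    if i.opname in ("LOAD_GLOBAL", "LOAD_ATTR", "LOAD_METHOD")
}
print(sorted(looked_up))
```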

Normally, TorchInductor, another component of torch.compile, further compiles the FX graphs into optimized kernels, but TorchDynamo allows for different backends to be used. In order to inspect the FX graphs that TorchDynamo outputs, let us create a custom backend that outputs the FX graph and simply returns the graph’s unoptimized forward method.

from typing import List
def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("custom backend called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward

# Reset since we are using a different backend.
torch._dynamo.reset()

opt_model = torch.compile(init_model(), backend=custom_backend)
opt_model(generate_data(16)[0])
custom backend called with FX graph:
opcode         name                                               target                                                      args                                                                                                                                                                                                                                                                                                                                                                                                                                                                kwargs
-------------  -------------------------------------------------  ----------------------------------------------------------  ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------  -----------------
placeholder    l_x_                                               L_x_                                                        ()                                                                                                                                                                                                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_conv0                           L__self___features_conv0                                    (l_x_,)                                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_norm0                           L__self___features_norm0                                    (l__self___features_conv0,)                                                                                                                                                                                                                                                                                                                                                                                                                                         {}
call_module    l__self___features_relu0                           L__self___features_relu0                                    (l__self___features_norm0,)                                                                                                                                                                                                                                                                                                                                                                                                                                         {}
call_module    l__self___features_pool0                           L__self___features_pool0                                    (l__self___features_relu0,)                                                                                                                                                                                                                                                                                                                                                                                                                                         {}
call_function  concated_features                                  <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_pool0], 1)                                                                                                                                                                                                                                                                                                                                                                                                                                     {}
call_module    l__self___features_denseblock1_denselayer1_norm1   L__self___features_denseblock1_denselayer1_norm1            (concated_features,)                                                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock1_denselayer1_relu1   L__self___features_denseblock1_denselayer1_relu1            (l__self___features_denseblock1_denselayer1_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output                                  L__self___features_denseblock1_denselayer1_conv1            (l__self___features_denseblock1_denselayer1_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock1_denselayer1_norm2   L__self___features_denseblock1_denselayer1_norm2            (bottleneck_output,)                                                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock1_denselayer1_relu2   L__self___features_denseblock1_denselayer1_relu2            (l__self___features_denseblock1_denselayer1_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features                                       L__self___features_denseblock1_denselayer1_conv2            (l__self___features_denseblock1_denselayer1_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_1                                <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_pool0, new_features], 1)                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock1_denselayer2_norm1   L__self___features_denseblock1_denselayer2_norm1            (concated_features_1,)                                                                                                                                                                                                                                                                                                                                                                                                                                              {}
call_module    l__self___features_denseblock1_denselayer2_relu1   L__self___features_denseblock1_denselayer2_relu1            (l__self___features_denseblock1_denselayer2_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_2                                L__self___features_denseblock1_denselayer2_conv1            (l__self___features_denseblock1_denselayer2_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock1_denselayer2_norm2   L__self___features_denseblock1_denselayer2_norm2            (bottleneck_output_2,)                                                                                                                                                                                                                                                                                                                                                                                                                                              {}
call_module    l__self___features_denseblock1_denselayer2_relu2   L__self___features_denseblock1_denselayer2_relu2            (l__self___features_denseblock1_denselayer2_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_2                                     L__self___features_denseblock1_denselayer2_conv2            (l__self___features_denseblock1_denselayer2_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_2                                <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_pool0, new_features, new_features_2], 1)                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock1_denselayer3_norm1   L__self___features_denseblock1_denselayer3_norm1            (concated_features_2,)                                                                                                                                                                                                                                                                                                                                                                                                                                              {}
call_module    l__self___features_denseblock1_denselayer3_relu1   L__self___features_denseblock1_denselayer3_relu1            (l__self___features_denseblock1_denselayer3_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_4                                L__self___features_denseblock1_denselayer3_conv1            (l__self___features_denseblock1_denselayer3_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock1_denselayer3_norm2   L__self___features_denseblock1_denselayer3_norm2            (bottleneck_output_4,)                                                                                                                                                                                                                                                                                                                                                                                                                                              {}
call_module    l__self___features_denseblock1_denselayer3_relu2   L__self___features_denseblock1_denselayer3_relu2            (l__self___features_denseblock1_denselayer3_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_4                                     L__self___features_denseblock1_denselayer3_conv2            (l__self___features_denseblock1_denselayer3_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_3                                <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_pool0, new_features, new_features_2, new_features_4], 1)                                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock1_denselayer4_norm1   L__self___features_denseblock1_denselayer4_norm1            (concated_features_3,)                                                                                                                                                                                                                                                                                                                                                                                                                                              {}
call_module    l__self___features_denseblock1_denselayer4_relu1   L__self___features_denseblock1_denselayer4_relu1            (l__self___features_denseblock1_denselayer4_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_6                                L__self___features_denseblock1_denselayer4_conv1            (l__self___features_denseblock1_denselayer4_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock1_denselayer4_norm2   L__self___features_denseblock1_denselayer4_norm2            (bottleneck_output_6,)                                                                                                                                                                                                                                                                                                                                                                                                                                              {}
call_module    l__self___features_denseblock1_denselayer4_relu2   L__self___features_denseblock1_denselayer4_relu2            (l__self___features_denseblock1_denselayer4_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_6                                     L__self___features_denseblock1_denselayer4_conv2            (l__self___features_denseblock1_denselayer4_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_4                                <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_pool0, new_features, new_features_2, new_features_4, new_features_6], 1)                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock1_denselayer5_norm1   L__self___features_denseblock1_denselayer5_norm1            (concated_features_4,)                                                                                                                                                                                                                                                                                                                                                                                                                                              {}
call_module    l__self___features_denseblock1_denselayer5_relu1   L__self___features_denseblock1_denselayer5_relu1            (l__self___features_denseblock1_denselayer5_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_8                                L__self___features_denseblock1_denselayer5_conv1            (l__self___features_denseblock1_denselayer5_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock1_denselayer5_norm2   L__self___features_denseblock1_denselayer5_norm2            (bottleneck_output_8,)                                                                                                                                                                                                                                                                                                                                                                                                                                              {}
call_module    l__self___features_denseblock1_denselayer5_relu2   L__self___features_denseblock1_denselayer5_relu2            (l__self___features_denseblock1_denselayer5_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_8                                     L__self___features_denseblock1_denselayer5_conv2            (l__self___features_denseblock1_denselayer5_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_5                                <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_pool0, new_features, new_features_2, new_features_4, new_features_6, new_features_8], 1)                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock1_denselayer6_norm1   L__self___features_denseblock1_denselayer6_norm1            (concated_features_5,)                                                                                                                                                                                                                                                                                                                                                                                                                                              {}
call_module    l__self___features_denseblock1_denselayer6_relu1   L__self___features_denseblock1_denselayer6_relu1            (l__self___features_denseblock1_denselayer6_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_10                               L__self___features_denseblock1_denselayer6_conv1            (l__self___features_denseblock1_denselayer6_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock1_denselayer6_norm2   L__self___features_denseblock1_denselayer6_norm2            (bottleneck_output_10,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock1_denselayer6_relu2   L__self___features_denseblock1_denselayer6_relu2            (l__self___features_denseblock1_denselayer6_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_10                                    L__self___features_denseblock1_denselayer6_conv2            (l__self___features_denseblock1_denselayer6_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  cat_6                                              <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_pool0, new_features, new_features_2, new_features_4, new_features_6, new_features_8, new_features_10], 1)                                                                                                                                                                                                                                                                                                                                      {}
call_module    l__self___features_transition1_norm                L__self___features_transition1_norm                         (cat_6,)                                                                                                                                                                                                                                                                                                                                                                                                                                                            {}
call_module    l__self___features_transition1_relu                L__self___features_transition1_relu                         (l__self___features_transition1_norm,)                                                                                                                                                                                                                                                                                                                                                                                                                              {}
call_module    l__self___features_transition1_conv                L__self___features_transition1_conv                         (l__self___features_transition1_relu,)                                                                                                                                                                                                                                                                                                                                                                                                                              {}
call_module    l__self___features_transition1_pool                L__self___features_transition1_pool                         (l__self___features_transition1_conv,)                                                                                                                                                                                                                                                                                                                                                                                                                              {}
call_function  concated_features_6                                <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition1_pool], 1)                                                                                                                                                                                                                                                                                                                                                                                                                          {}
call_module    l__self___features_denseblock2_denselayer1_norm1   L__self___features_denseblock2_denselayer1_norm1            (concated_features_6,)                                                                                                                                                                                                                                                                                                                                                                                                                                              {}
call_module    l__self___features_denseblock2_denselayer1_relu1   L__self___features_denseblock2_denselayer1_relu1            (l__self___features_denseblock2_denselayer1_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_12                               L__self___features_denseblock2_denselayer1_conv1            (l__self___features_denseblock2_denselayer1_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock2_denselayer1_norm2   L__self___features_denseblock2_denselayer1_norm2            (bottleneck_output_12,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock2_denselayer1_relu2   L__self___features_denseblock2_denselayer1_relu2            (l__self___features_denseblock2_denselayer1_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_12                                    L__self___features_denseblock2_denselayer1_conv2            (l__self___features_denseblock2_denselayer1_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_7                                <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition1_pool, new_features_12], 1)                                                                                                                                                                                                                                                                                                                                                                                                         {}
call_module    l__self___features_denseblock2_denselayer2_norm1   L__self___features_denseblock2_denselayer2_norm1            (concated_features_7,)                                                                                                                                                                                                                                                                                                                                                                                                                                              {}
call_module    l__self___features_denseblock2_denselayer2_relu1   L__self___features_denseblock2_denselayer2_relu1            (l__self___features_denseblock2_denselayer2_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_14                               L__self___features_denseblock2_denselayer2_conv1            (l__self___features_denseblock2_denselayer2_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock2_denselayer2_norm2   L__self___features_denseblock2_denselayer2_norm2            (bottleneck_output_14,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock2_denselayer2_relu2   L__self___features_denseblock2_denselayer2_relu2            (l__self___features_denseblock2_denselayer2_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_14                                    L__self___features_denseblock2_denselayer2_conv2            (l__self___features_denseblock2_denselayer2_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_8                                <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition1_pool, new_features_12, new_features_14], 1)                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock2_denselayer3_norm1   L__self___features_denseblock2_denselayer3_norm1            (concated_features_8,)                                                                                                                                                                                                                                                                                                                                                                                                                                              {}
call_module    l__self___features_denseblock2_denselayer3_relu1   L__self___features_denseblock2_denselayer3_relu1            (l__self___features_denseblock2_denselayer3_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_16                               L__self___features_denseblock2_denselayer3_conv1            (l__self___features_denseblock2_denselayer3_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock2_denselayer3_norm2   L__self___features_denseblock2_denselayer3_norm2            (bottleneck_output_16,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock2_denselayer3_relu2   L__self___features_denseblock2_denselayer3_relu2            (l__self___features_denseblock2_denselayer3_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_16                                    L__self___features_denseblock2_denselayer3_conv2            (l__self___features_denseblock2_denselayer3_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_9                                <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition1_pool, new_features_12, new_features_14, new_features_16], 1)                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock2_denselayer4_norm1   L__self___features_denseblock2_denselayer4_norm1            (concated_features_9,)                                                                                                                                                                                                                                                                                                                                                                                                                                              {}
call_module    l__self___features_denseblock2_denselayer4_relu1   L__self___features_denseblock2_denselayer4_relu1            (l__self___features_denseblock2_denselayer4_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_18                               L__self___features_denseblock2_denselayer4_conv1            (l__self___features_denseblock2_denselayer4_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock2_denselayer4_norm2   L__self___features_denseblock2_denselayer4_norm2            (bottleneck_output_18,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock2_denselayer4_relu2   L__self___features_denseblock2_denselayer4_relu2            (l__self___features_denseblock2_denselayer4_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_18                                    L__self___features_denseblock2_denselayer4_conv2            (l__self___features_denseblock2_denselayer4_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_10                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition1_pool, new_features_12, new_features_14, new_features_16, new_features_18], 1)                                                                                                                                                                                                                                                                                                                                                      {}
call_module    l__self___features_denseblock2_denselayer5_norm1   L__self___features_denseblock2_denselayer5_norm1            (concated_features_10,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock2_denselayer5_relu1   L__self___features_denseblock2_denselayer5_relu1            (l__self___features_denseblock2_denselayer5_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_20                               L__self___features_denseblock2_denselayer5_conv1            (l__self___features_denseblock2_denselayer5_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock2_denselayer5_norm2   L__self___features_denseblock2_denselayer5_norm2            (bottleneck_output_20,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock2_denselayer5_relu2   L__self___features_denseblock2_denselayer5_relu2            (l__self___features_denseblock2_denselayer5_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_20                                    L__self___features_denseblock2_denselayer5_conv2            (l__self___features_denseblock2_denselayer5_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_11                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition1_pool, new_features_12, new_features_14, new_features_16, new_features_18, new_features_20], 1)                                                                                                                                                                                                                                                                                                                                     {}
call_module    l__self___features_denseblock2_denselayer6_norm1   L__self___features_denseblock2_denselayer6_norm1            (concated_features_11,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock2_denselayer6_relu1   L__self___features_denseblock2_denselayer6_relu1            (l__self___features_denseblock2_denselayer6_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_22                               L__self___features_denseblock2_denselayer6_conv1            (l__self___features_denseblock2_denselayer6_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock2_denselayer6_norm2   L__self___features_denseblock2_denselayer6_norm2            (bottleneck_output_22,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock2_denselayer6_relu2   L__self___features_denseblock2_denselayer6_relu2            (l__self___features_denseblock2_denselayer6_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_22                                    L__self___features_denseblock2_denselayer6_conv2            (l__self___features_denseblock2_denselayer6_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_12                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition1_pool, new_features_12, new_features_14, new_features_16, new_features_18, new_features_20, new_features_22], 1)                                                                                                                                                                                                                                                                                                                    {}
call_module    l__self___features_denseblock2_denselayer7_norm1   L__self___features_denseblock2_denselayer7_norm1            (concated_features_12,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock2_denselayer7_relu1   L__self___features_denseblock2_denselayer7_relu1            (l__self___features_denseblock2_denselayer7_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_24                               L__self___features_denseblock2_denselayer7_conv1            (l__self___features_denseblock2_denselayer7_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock2_denselayer7_norm2   L__self___features_denseblock2_denselayer7_norm2            (bottleneck_output_24,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock2_denselayer7_relu2   L__self___features_denseblock2_denselayer7_relu2            (l__self___features_denseblock2_denselayer7_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_24                                    L__self___features_denseblock2_denselayer7_conv2            (l__self___features_denseblock2_denselayer7_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_13                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition1_pool, new_features_12, new_features_14, new_features_16, new_features_18, new_features_20, new_features_22, new_features_24], 1)                                                                                                                                                                                                                                                                                                   {}
call_module    l__self___features_denseblock2_denselayer8_norm1   L__self___features_denseblock2_denselayer8_norm1            (concated_features_13,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock2_denselayer8_relu1   L__self___features_denseblock2_denselayer8_relu1            (l__self___features_denseblock2_denselayer8_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_26                               L__self___features_denseblock2_denselayer8_conv1            (l__self___features_denseblock2_denselayer8_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock2_denselayer8_norm2   L__self___features_denseblock2_denselayer8_norm2            (bottleneck_output_26,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock2_denselayer8_relu2   L__self___features_denseblock2_denselayer8_relu2            (l__self___features_denseblock2_denselayer8_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_26                                    L__self___features_denseblock2_denselayer8_conv2            (l__self___features_denseblock2_denselayer8_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_14                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition1_pool, new_features_12, new_features_14, new_features_16, new_features_18, new_features_20, new_features_22, new_features_24, new_features_26], 1)                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock2_denselayer9_norm1   L__self___features_denseblock2_denselayer9_norm1            (concated_features_14,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock2_denselayer9_relu1   L__self___features_denseblock2_denselayer9_relu1            (l__self___features_denseblock2_denselayer9_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_28                               L__self___features_denseblock2_denselayer9_conv1            (l__self___features_denseblock2_denselayer9_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock2_denselayer9_norm2   L__self___features_denseblock2_denselayer9_norm2            (bottleneck_output_28,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock2_denselayer9_relu2   L__self___features_denseblock2_denselayer9_relu2            (l__self___features_denseblock2_denselayer9_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_28                                    L__self___features_denseblock2_denselayer9_conv2            (l__self___features_denseblock2_denselayer9_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_15                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition1_pool, new_features_12, new_features_14, new_features_16, new_features_18, new_features_20, new_features_22, new_features_24, new_features_26, new_features_28], 1)                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock2_denselayer10_norm1  L__self___features_denseblock2_denselayer10_norm1           (concated_features_15,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock2_denselayer10_relu1  L__self___features_denseblock2_denselayer10_relu1           (l__self___features_denseblock2_denselayer10_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    bottleneck_output_30                               L__self___features_denseblock2_denselayer10_conv1           (l__self___features_denseblock2_denselayer10_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock2_denselayer10_norm2  L__self___features_denseblock2_denselayer10_norm2           (bottleneck_output_30,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock2_denselayer10_relu2  L__self___features_denseblock2_denselayer10_relu2           (l__self___features_denseblock2_denselayer10_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    new_features_30                                    L__self___features_denseblock2_denselayer10_conv2           (l__self___features_denseblock2_denselayer10_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_function  concated_features_16                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition1_pool, new_features_12, new_features_14, new_features_16, new_features_18, new_features_20, new_features_22, new_features_24, new_features_26, new_features_28, new_features_30], 1)                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock2_denselayer11_norm1  L__self___features_denseblock2_denselayer11_norm1           (concated_features_16,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock2_denselayer11_relu1  L__self___features_denseblock2_denselayer11_relu1           (l__self___features_denseblock2_denselayer11_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    bottleneck_output_32                               L__self___features_denseblock2_denselayer11_conv1           (l__self___features_denseblock2_denselayer11_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock2_denselayer11_norm2  L__self___features_denseblock2_denselayer11_norm2           (bottleneck_output_32,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock2_denselayer11_relu2  L__self___features_denseblock2_denselayer11_relu2           (l__self___features_denseblock2_denselayer11_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    new_features_32                                    L__self___features_denseblock2_denselayer11_conv2           (l__self___features_denseblock2_denselayer11_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_function  concated_features_17                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition1_pool, new_features_12, new_features_14, new_features_16, new_features_18, new_features_20, new_features_22, new_features_24, new_features_26, new_features_28, new_features_30, new_features_32], 1)                                                                                                                                                                                                                               {}
call_module    l__self___features_denseblock2_denselayer12_norm1  L__self___features_denseblock2_denselayer12_norm1           (concated_features_17,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock2_denselayer12_relu1  L__self___features_denseblock2_denselayer12_relu1           (l__self___features_denseblock2_denselayer12_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    bottleneck_output_34                               L__self___features_denseblock2_denselayer12_conv1           (l__self___features_denseblock2_denselayer12_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock2_denselayer12_norm2  L__self___features_denseblock2_denselayer12_norm2           (bottleneck_output_34,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock2_denselayer12_relu2  L__self___features_denseblock2_denselayer12_relu2           (l__self___features_denseblock2_denselayer12_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    new_features_34                                    L__self___features_denseblock2_denselayer12_conv2           (l__self___features_denseblock2_denselayer12_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_function  cat_19                                             <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition1_pool, new_features_12, new_features_14, new_features_16, new_features_18, new_features_20, new_features_22, new_features_24, new_features_26, new_features_28, new_features_30, new_features_32, new_features_34], 1)                                                                                                                                                                                                              {}
call_module    l__self___features_transition2_norm                L__self___features_transition2_norm                         (cat_19,)                                                                                                                                                                                                                                                                                                                                                                                                                                                           {}
call_module    l__self___features_transition2_relu                L__self___features_transition2_relu                         (l__self___features_transition2_norm,)                                                                                                                                                                                                                                                                                                                                                                                                                              {}
call_module    l__self___features_transition2_conv                L__self___features_transition2_conv                         (l__self___features_transition2_relu,)                                                                                                                                                                                                                                                                                                                                                                                                                              {}
call_module    l__self___features_transition2_pool                L__self___features_transition2_pool                         (l__self___features_transition2_conv,)                                                                                                                                                                                                                                                                                                                                                                                                                              {}
call_function  concated_features_18                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition2_pool], 1)                                                                                                                                                                                                                                                                                                                                                                                                                          {}
call_module    l__self___features_denseblock3_denselayer1_norm1   L__self___features_denseblock3_denselayer1_norm1            (concated_features_18,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer1_relu1   L__self___features_denseblock3_denselayer1_relu1            (l__self___features_denseblock3_denselayer1_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_36                               L__self___features_denseblock3_denselayer1_conv1            (l__self___features_denseblock3_denselayer1_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock3_denselayer1_norm2   L__self___features_denseblock3_denselayer1_norm2            (bottleneck_output_36,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer1_relu2   L__self___features_denseblock3_denselayer1_relu2            (l__self___features_denseblock3_denselayer1_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_36                                    L__self___features_denseblock3_denselayer1_conv2            (l__self___features_denseblock3_denselayer1_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_19                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition2_pool, new_features_36], 1)                                                                                                                                                                                                                                                                                                                                                                                                         {}
call_module    l__self___features_denseblock3_denselayer2_norm1   L__self___features_denseblock3_denselayer2_norm1            (concated_features_19,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer2_relu1   L__self___features_denseblock3_denselayer2_relu1            (l__self___features_denseblock3_denselayer2_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_38                               L__self___features_denseblock3_denselayer2_conv1            (l__self___features_denseblock3_denselayer2_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock3_denselayer2_norm2   L__self___features_denseblock3_denselayer2_norm2            (bottleneck_output_38,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer2_relu2   L__self___features_denseblock3_denselayer2_relu2            (l__self___features_denseblock3_denselayer2_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_38                                    L__self___features_denseblock3_denselayer2_conv2            (l__self___features_denseblock3_denselayer2_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_20                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition2_pool, new_features_36, new_features_38], 1)                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer3_norm1   L__self___features_denseblock3_denselayer3_norm1            (concated_features_20,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer3_relu1   L__self___features_denseblock3_denselayer3_relu1            (l__self___features_denseblock3_denselayer3_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_40                               L__self___features_denseblock3_denselayer3_conv1            (l__self___features_denseblock3_denselayer3_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock3_denselayer3_norm2   L__self___features_denseblock3_denselayer3_norm2            (bottleneck_output_40,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer3_relu2   L__self___features_denseblock3_denselayer3_relu2            (l__self___features_denseblock3_denselayer3_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_40                                    L__self___features_denseblock3_denselayer3_conv2            (l__self___features_denseblock3_denselayer3_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_21                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40], 1)                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock3_denselayer4_norm1   L__self___features_denseblock3_denselayer4_norm1            (concated_features_21,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer4_relu1   L__self___features_denseblock3_denselayer4_relu1            (l__self___features_denseblock3_denselayer4_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_42                               L__self___features_denseblock3_denselayer4_conv1            (l__self___features_denseblock3_denselayer4_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock3_denselayer4_norm2   L__self___features_denseblock3_denselayer4_norm2            (bottleneck_output_42,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer4_relu2   L__self___features_denseblock3_denselayer4_relu2            (l__self___features_denseblock3_denselayer4_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_42                                    L__self___features_denseblock3_denselayer4_conv2            (l__self___features_denseblock3_denselayer4_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_22                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42], 1)                                                                                                                                                                                                                                                                                                                                                      {}
call_module    l__self___features_denseblock3_denselayer5_norm1   L__self___features_denseblock3_denselayer5_norm1            (concated_features_22,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer5_relu1   L__self___features_denseblock3_denselayer5_relu1            (l__self___features_denseblock3_denselayer5_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_44                               L__self___features_denseblock3_denselayer5_conv1            (l__self___features_denseblock3_denselayer5_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock3_denselayer5_norm2   L__self___features_denseblock3_denselayer5_norm2            (bottleneck_output_44,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer5_relu2   L__self___features_denseblock3_denselayer5_relu2            (l__self___features_denseblock3_denselayer5_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_44                                    L__self___features_denseblock3_denselayer5_conv2            (l__self___features_denseblock3_denselayer5_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_23                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44], 1)                                                                                                                                                                                                                                                                                                                                     {}
call_module    l__self___features_denseblock3_denselayer6_norm1   L__self___features_denseblock3_denselayer6_norm1            (concated_features_23,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer6_relu1   L__self___features_denseblock3_denselayer6_relu1            (l__self___features_denseblock3_denselayer6_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_46                               L__self___features_denseblock3_denselayer6_conv1            (l__self___features_denseblock3_denselayer6_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock3_denselayer6_norm2   L__self___features_denseblock3_denselayer6_norm2            (bottleneck_output_46,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer6_relu2   L__self___features_denseblock3_denselayer6_relu2            (l__self___features_denseblock3_denselayer6_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_46                                    L__self___features_denseblock3_denselayer6_conv2            (l__self___features_denseblock3_denselayer6_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_24                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46], 1)                                                                                                                                                                                                                                                                                                                    {}
call_module    l__self___features_denseblock3_denselayer7_norm1   L__self___features_denseblock3_denselayer7_norm1            (concated_features_24,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer7_relu1   L__self___features_denseblock3_denselayer7_relu1            (l__self___features_denseblock3_denselayer7_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_48                               L__self___features_denseblock3_denselayer7_conv1            (l__self___features_denseblock3_denselayer7_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock3_denselayer7_norm2   L__self___features_denseblock3_denselayer7_norm2            (bottleneck_output_48,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer7_relu2   L__self___features_denseblock3_denselayer7_relu2            (l__self___features_denseblock3_denselayer7_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_48                                    L__self___features_denseblock3_denselayer7_conv2            (l__self___features_denseblock3_denselayer7_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_25                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48], 1)                                                                                                                                                                                                                                                                                                   {}
call_module    l__self___features_denseblock3_denselayer8_norm1   L__self___features_denseblock3_denselayer8_norm1            (concated_features_25,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer8_relu1   L__self___features_denseblock3_denselayer8_relu1            (l__self___features_denseblock3_denselayer8_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_50                               L__self___features_denseblock3_denselayer8_conv1            (l__self___features_denseblock3_denselayer8_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock3_denselayer8_norm2   L__self___features_denseblock3_denselayer8_norm2            (bottleneck_output_50,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer8_relu2   L__self___features_denseblock3_denselayer8_relu2            (l__self___features_denseblock3_denselayer8_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_50                                    L__self___features_denseblock3_denselayer8_conv2            (l__self___features_denseblock3_denselayer8_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_26                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50], 1)                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock3_denselayer9_norm1   L__self___features_denseblock3_denselayer9_norm1            (concated_features_26,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer9_relu1   L__self___features_denseblock3_denselayer9_relu1            (l__self___features_denseblock3_denselayer9_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_52                               L__self___features_denseblock3_denselayer9_conv1            (l__self___features_denseblock3_denselayer9_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock3_denselayer9_norm2   L__self___features_denseblock3_denselayer9_norm2            (bottleneck_output_52,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer9_relu2   L__self___features_denseblock3_denselayer9_relu2            (l__self___features_denseblock3_denselayer9_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_52                                    L__self___features_denseblock3_denselayer9_conv2            (l__self___features_denseblock3_denselayer9_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_27                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52], 1)                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock3_denselayer10_norm1  L__self___features_denseblock3_denselayer10_norm1           (concated_features_27,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer10_relu1  L__self___features_denseblock3_denselayer10_relu1           (l__self___features_denseblock3_denselayer10_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    bottleneck_output_54                               L__self___features_denseblock3_denselayer10_conv1           (l__self___features_denseblock3_denselayer10_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock3_denselayer10_norm2  L__self___features_denseblock3_denselayer10_norm2           (bottleneck_output_54,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer10_relu2  L__self___features_denseblock3_denselayer10_relu2           (l__self___features_denseblock3_denselayer10_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    new_features_54                                    L__self___features_denseblock3_denselayer10_conv2           (l__self___features_denseblock3_denselayer10_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_function  concated_features_28                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52, new_features_54], 1)                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock3_denselayer11_norm1  L__self___features_denseblock3_denselayer11_norm1           (concated_features_28,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer11_relu1  L__self___features_denseblock3_denselayer11_relu1           (l__self___features_denseblock3_denselayer11_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    bottleneck_output_56                               L__self___features_denseblock3_denselayer11_conv1           (l__self___features_denseblock3_denselayer11_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock3_denselayer11_norm2  L__self___features_denseblock3_denselayer11_norm2           (bottleneck_output_56,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer11_relu2  L__self___features_denseblock3_denselayer11_relu2           (l__self___features_denseblock3_denselayer11_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    new_features_56                                    L__self___features_denseblock3_denselayer11_conv2           (l__self___features_denseblock3_denselayer11_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_function  concated_features_29                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52, new_features_54, new_features_56], 1)                                                                                                                                                                                                                               {}
call_module    l__self___features_denseblock3_denselayer12_norm1  L__self___features_denseblock3_denselayer12_norm1           (concated_features_29,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer12_relu1  L__self___features_denseblock3_denselayer12_relu1           (l__self___features_denseblock3_denselayer12_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    bottleneck_output_58                               L__self___features_denseblock3_denselayer12_conv1           (l__self___features_denseblock3_denselayer12_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock3_denselayer12_norm2  L__self___features_denseblock3_denselayer12_norm2           (bottleneck_output_58,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer12_relu2  L__self___features_denseblock3_denselayer12_relu2           (l__self___features_denseblock3_denselayer12_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    new_features_58                                    L__self___features_denseblock3_denselayer12_conv2           (l__self___features_denseblock3_denselayer12_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_function  concated_features_30                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52, new_features_54, new_features_56, new_features_58], 1)                                                                                                                                                                                                              {}
call_module    l__self___features_denseblock3_denselayer13_norm1  L__self___features_denseblock3_denselayer13_norm1           (concated_features_30,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer13_relu1  L__self___features_denseblock3_denselayer13_relu1           (l__self___features_denseblock3_denselayer13_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    bottleneck_output_60                               L__self___features_denseblock3_denselayer13_conv1           (l__self___features_denseblock3_denselayer13_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock3_denselayer13_norm2  L__self___features_denseblock3_denselayer13_norm2           (bottleneck_output_60,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer13_relu2  L__self___features_denseblock3_denselayer13_relu2           (l__self___features_denseblock3_denselayer13_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    new_features_60                                    L__self___features_denseblock3_denselayer13_conv2           (l__self___features_denseblock3_denselayer13_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_function  concated_features_31                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52, new_features_54, new_features_56, new_features_58, new_features_60], 1)                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer14_norm1  L__self___features_denseblock3_denselayer14_norm1           (concated_features_31,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer14_relu1  L__self___features_denseblock3_denselayer14_relu1           (l__self___features_denseblock3_denselayer14_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    bottleneck_output_62                               L__self___features_denseblock3_denselayer14_conv1           (l__self___features_denseblock3_denselayer14_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock3_denselayer14_norm2  L__self___features_denseblock3_denselayer14_norm2           (bottleneck_output_62,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer14_relu2  L__self___features_denseblock3_denselayer14_relu2           (l__self___features_denseblock3_denselayer14_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    new_features_62                                    L__self___features_denseblock3_denselayer14_conv2           (l__self___features_denseblock3_denselayer14_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_function  concated_features_32                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52, new_features_54, new_features_56, new_features_58, new_features_60, new_features_62], 1)                                                                                                                                                                            {}
call_module    l__self___features_denseblock3_denselayer15_norm1  L__self___features_denseblock3_denselayer15_norm1           (concated_features_32,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer15_relu1  L__self___features_denseblock3_denselayer15_relu1           (l__self___features_denseblock3_denselayer15_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    bottleneck_output_64                               L__self___features_denseblock3_denselayer15_conv1           (l__self___features_denseblock3_denselayer15_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock3_denselayer15_norm2  L__self___features_denseblock3_denselayer15_norm2           (bottleneck_output_64,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer15_relu2  L__self___features_denseblock3_denselayer15_relu2           (l__self___features_denseblock3_denselayer15_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    new_features_64                                    L__self___features_denseblock3_denselayer15_conv2           (l__self___features_denseblock3_denselayer15_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_function  concated_features_33                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52, new_features_54, new_features_56, new_features_58, new_features_60, new_features_62, new_features_64], 1)                                                                                                                                                           {}
call_module    l__self___features_denseblock3_denselayer16_norm1  L__self___features_denseblock3_denselayer16_norm1           (concated_features_33,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer16_relu1  L__self___features_denseblock3_denselayer16_relu1           (l__self___features_denseblock3_denselayer16_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    bottleneck_output_66                               L__self___features_denseblock3_denselayer16_conv1           (l__self___features_denseblock3_denselayer16_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock3_denselayer16_norm2  L__self___features_denseblock3_denselayer16_norm2           (bottleneck_output_66,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer16_relu2  L__self___features_denseblock3_denselayer16_relu2           (l__self___features_denseblock3_denselayer16_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    new_features_66                                    L__self___features_denseblock3_denselayer16_conv2           (l__self___features_denseblock3_denselayer16_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_function  concated_features_34                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52, new_features_54, new_features_56, new_features_58, new_features_60, new_features_62, new_features_64, new_features_66], 1)                                                                                                                                          {}
call_module    l__self___features_denseblock3_denselayer17_norm1  L__self___features_denseblock3_denselayer17_norm1           (concated_features_34,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer17_relu1  L__self___features_denseblock3_denselayer17_relu1           (l__self___features_denseblock3_denselayer17_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    bottleneck_output_68                               L__self___features_denseblock3_denselayer17_conv1           (l__self___features_denseblock3_denselayer17_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock3_denselayer17_norm2  L__self___features_denseblock3_denselayer17_norm2           (bottleneck_output_68,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer17_relu2  L__self___features_denseblock3_denselayer17_relu2           (l__self___features_denseblock3_denselayer17_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    new_features_68                                    L__self___features_denseblock3_denselayer17_conv2           (l__self___features_denseblock3_denselayer17_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_function  concated_features_35                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52, new_features_54, new_features_56, new_features_58, new_features_60, new_features_62, new_features_64, new_features_66, new_features_68], 1)                                                                                                                         {}
call_module    l__self___features_denseblock3_denselayer18_norm1  L__self___features_denseblock3_denselayer18_norm1           (concated_features_35,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer18_relu1  L__self___features_denseblock3_denselayer18_relu1           (l__self___features_denseblock3_denselayer18_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    bottleneck_output_70                               L__self___features_denseblock3_denselayer18_conv1           (l__self___features_denseblock3_denselayer18_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock3_denselayer18_norm2  L__self___features_denseblock3_denselayer18_norm2           (bottleneck_output_70,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer18_relu2  L__self___features_denseblock3_denselayer18_relu2           (l__self___features_denseblock3_denselayer18_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    new_features_70                                    L__self___features_denseblock3_denselayer18_conv2           (l__self___features_denseblock3_denselayer18_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_function  concated_features_36                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52, new_features_54, new_features_56, new_features_58, new_features_60, new_features_62, new_features_64, new_features_66, new_features_68, new_features_70], 1)                                                                                                        {}
call_module    l__self___features_denseblock3_denselayer19_norm1  L__self___features_denseblock3_denselayer19_norm1           (concated_features_36,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer19_relu1  L__self___features_denseblock3_denselayer19_relu1           (l__self___features_denseblock3_denselayer19_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    bottleneck_output_72                               L__self___features_denseblock3_denselayer19_conv1           (l__self___features_denseblock3_denselayer19_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock3_denselayer19_norm2  L__self___features_denseblock3_denselayer19_norm2           (bottleneck_output_72,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer19_relu2  L__self___features_denseblock3_denselayer19_relu2           (l__self___features_denseblock3_denselayer19_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    new_features_72                                    L__self___features_denseblock3_denselayer19_conv2           (l__self___features_denseblock3_denselayer19_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_function  concated_features_37                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52, new_features_54, new_features_56, new_features_58, new_features_60, new_features_62, new_features_64, new_features_66, new_features_68, new_features_70, new_features_72], 1)                                                                                       {}
call_module    l__self___features_denseblock3_denselayer20_norm1  L__self___features_denseblock3_denselayer20_norm1           (concated_features_37,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer20_relu1  L__self___features_denseblock3_denselayer20_relu1           (l__self___features_denseblock3_denselayer20_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    bottleneck_output_74                               L__self___features_denseblock3_denselayer20_conv1           (l__self___features_denseblock3_denselayer20_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock3_denselayer20_norm2  L__self___features_denseblock3_denselayer20_norm2           (bottleneck_output_74,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer20_relu2  L__self___features_denseblock3_denselayer20_relu2           (l__self___features_denseblock3_denselayer20_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    new_features_74                                    L__self___features_denseblock3_denselayer20_conv2           (l__self___features_denseblock3_denselayer20_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_function  concated_features_38                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52, new_features_54, new_features_56, new_features_58, new_features_60, new_features_62, new_features_64, new_features_66, new_features_68, new_features_70, new_features_72, new_features_74], 1)                                                                      {}
call_module    l__self___features_denseblock3_denselayer21_norm1  L__self___features_denseblock3_denselayer21_norm1           (concated_features_38,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer21_relu1  L__self___features_denseblock3_denselayer21_relu1           (l__self___features_denseblock3_denselayer21_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    bottleneck_output_76                               L__self___features_denseblock3_denselayer21_conv1           (l__self___features_denseblock3_denselayer21_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock3_denselayer21_norm2  L__self___features_denseblock3_denselayer21_norm2           (bottleneck_output_76,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer21_relu2  L__self___features_denseblock3_denselayer21_relu2           (l__self___features_denseblock3_denselayer21_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    new_features_76                                    L__self___features_denseblock3_denselayer21_conv2           (l__self___features_denseblock3_denselayer21_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_function  concated_features_39                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52, new_features_54, new_features_56, new_features_58, new_features_60, new_features_62, new_features_64, new_features_66, new_features_68, new_features_70, new_features_72, new_features_74, new_features_76], 1)                                                     {}
call_module    l__self___features_denseblock3_denselayer22_norm1  L__self___features_denseblock3_denselayer22_norm1           (concated_features_39,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer22_relu1  L__self___features_denseblock3_denselayer22_relu1           (l__self___features_denseblock3_denselayer22_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    bottleneck_output_78                               L__self___features_denseblock3_denselayer22_conv1           (l__self___features_denseblock3_denselayer22_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock3_denselayer22_norm2  L__self___features_denseblock3_denselayer22_norm2           (bottleneck_output_78,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer22_relu2  L__self___features_denseblock3_denselayer22_relu2           (l__self___features_denseblock3_denselayer22_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    new_features_78                                    L__self___features_denseblock3_denselayer22_conv2           (l__self___features_denseblock3_denselayer22_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_function  concated_features_40                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52, new_features_54, new_features_56, new_features_58, new_features_60, new_features_62, new_features_64, new_features_66, new_features_68, new_features_70, new_features_72, new_features_74, new_features_76, new_features_78], 1)                                    {}
call_module    l__self___features_denseblock3_denselayer23_norm1  L__self___features_denseblock3_denselayer23_norm1           (concated_features_40,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer23_relu1  L__self___features_denseblock3_denselayer23_relu1           (l__self___features_denseblock3_denselayer23_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    bottleneck_output_80                               L__self___features_denseblock3_denselayer23_conv1           (l__self___features_denseblock3_denselayer23_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock3_denselayer23_norm2  L__self___features_denseblock3_denselayer23_norm2           (bottleneck_output_80,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer23_relu2  L__self___features_denseblock3_denselayer23_relu2           (l__self___features_denseblock3_denselayer23_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    new_features_80                                    L__self___features_denseblock3_denselayer23_conv2           (l__self___features_denseblock3_denselayer23_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_function  concated_features_41                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52, new_features_54, new_features_56, new_features_58, new_features_60, new_features_62, new_features_64, new_features_66, new_features_68, new_features_70, new_features_72, new_features_74, new_features_76, new_features_78, new_features_80], 1)                   {}
call_module    l__self___features_denseblock3_denselayer24_norm1  L__self___features_denseblock3_denselayer24_norm1           (concated_features_41,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer24_relu1  L__self___features_denseblock3_denselayer24_relu1           (l__self___features_denseblock3_denselayer24_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    bottleneck_output_82                               L__self___features_denseblock3_denselayer24_conv1           (l__self___features_denseblock3_denselayer24_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock3_denselayer24_norm2  L__self___features_denseblock3_denselayer24_norm2           (bottleneck_output_82,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock3_denselayer24_relu2  L__self___features_denseblock3_denselayer24_relu2           (l__self___features_denseblock3_denselayer24_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    new_features_82                                    L__self___features_denseblock3_denselayer24_conv2           (l__self___features_denseblock3_denselayer24_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_function  cat_44                                             <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition2_pool, new_features_36, new_features_38, new_features_40, new_features_42, new_features_44, new_features_46, new_features_48, new_features_50, new_features_52, new_features_54, new_features_56, new_features_58, new_features_60, new_features_62, new_features_64, new_features_66, new_features_68, new_features_70, new_features_72, new_features_74, new_features_76, new_features_78, new_features_80, new_features_82], 1)  {}
call_module    l__self___features_transition3_norm                L__self___features_transition3_norm                         (cat_44,)                                                                                                                                                                                                                                                                                                                                                                                                                                                           {}
call_module    l__self___features_transition3_relu                L__self___features_transition3_relu                         (l__self___features_transition3_norm,)                                                                                                                                                                                                                                                                                                                                                                                                                              {}
call_module    l__self___features_transition3_conv                L__self___features_transition3_conv                         (l__self___features_transition3_relu,)                                                                                                                                                                                                                                                                                                                                                                                                                              {}
call_module    l__self___features_transition3_pool                L__self___features_transition3_pool                         (l__self___features_transition3_conv,)                                                                                                                                                                                                                                                                                                                                                                                                                              {}
call_function  concated_features_42                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition3_pool], 1)                                                                                                                                                                                                                                                                                                                                                                                                                          {}
call_module    l__self___features_denseblock4_denselayer1_norm1   L__self___features_denseblock4_denselayer1_norm1            (concated_features_42,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock4_denselayer1_relu1   L__self___features_denseblock4_denselayer1_relu1            (l__self___features_denseblock4_denselayer1_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_84                               L__self___features_denseblock4_denselayer1_conv1            (l__self___features_denseblock4_denselayer1_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock4_denselayer1_norm2   L__self___features_denseblock4_denselayer1_norm2            (bottleneck_output_84,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock4_denselayer1_relu2   L__self___features_denseblock4_denselayer1_relu2            (l__self___features_denseblock4_denselayer1_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_84                                    L__self___features_denseblock4_denselayer1_conv2            (l__self___features_denseblock4_denselayer1_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_43                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition3_pool, new_features_84], 1)                                                                                                                                                                                                                                                                                                                                                                                                         {}
call_module    l__self___features_denseblock4_denselayer2_norm1   L__self___features_denseblock4_denselayer2_norm1            (concated_features_43,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock4_denselayer2_relu1   L__self___features_denseblock4_denselayer2_relu1            (l__self___features_denseblock4_denselayer2_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_86                               L__self___features_denseblock4_denselayer2_conv1            (l__self___features_denseblock4_denselayer2_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock4_denselayer2_norm2   L__self___features_denseblock4_denselayer2_norm2            (bottleneck_output_86,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock4_denselayer2_relu2   L__self___features_denseblock4_denselayer2_relu2            (l__self___features_denseblock4_denselayer2_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_86                                    L__self___features_denseblock4_denselayer2_conv2            (l__self___features_denseblock4_denselayer2_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_44                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition3_pool, new_features_84, new_features_86], 1)                                                                                                                                                                                                                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer3_norm1   L__self___features_denseblock4_denselayer3_norm1            (concated_features_44,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock4_denselayer3_relu1   L__self___features_denseblock4_denselayer3_relu1            (l__self___features_denseblock4_denselayer3_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_88                               L__self___features_denseblock4_denselayer3_conv1            (l__self___features_denseblock4_denselayer3_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock4_denselayer3_norm2   L__self___features_denseblock4_denselayer3_norm2            (bottleneck_output_88,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock4_denselayer3_relu2   L__self___features_denseblock4_denselayer3_relu2            (l__self___features_denseblock4_denselayer3_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_88                                    L__self___features_denseblock4_denselayer3_conv2            (l__self___features_denseblock4_denselayer3_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_45                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition3_pool, new_features_84, new_features_86, new_features_88], 1)                                                                                                                                                                                                                                                                                                                                                                       {}
call_module    l__self___features_denseblock4_denselayer4_norm1   L__self___features_denseblock4_denselayer4_norm1            (concated_features_45,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock4_denselayer4_relu1   L__self___features_denseblock4_denselayer4_relu1            (l__self___features_denseblock4_denselayer4_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_90                               L__self___features_denseblock4_denselayer4_conv1            (l__self___features_denseblock4_denselayer4_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock4_denselayer4_norm2   L__self___features_denseblock4_denselayer4_norm2            (bottleneck_output_90,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock4_denselayer4_relu2   L__self___features_denseblock4_denselayer4_relu2            (l__self___features_denseblock4_denselayer4_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_90                                    L__self___features_denseblock4_denselayer4_conv2            (l__self___features_denseblock4_denselayer4_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_46                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition3_pool, new_features_84, new_features_86, new_features_88, new_features_90], 1)                                                                                                                                                                                                                                                                                                                                                      {}
call_module    l__self___features_denseblock4_denselayer5_norm1   L__self___features_denseblock4_denselayer5_norm1            (concated_features_46,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock4_denselayer5_relu1   L__self___features_denseblock4_denselayer5_relu1            (l__self___features_denseblock4_denselayer5_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_92                               L__self___features_denseblock4_denselayer5_conv1            (l__self___features_denseblock4_denselayer5_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock4_denselayer5_norm2   L__self___features_denseblock4_denselayer5_norm2            (bottleneck_output_92,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock4_denselayer5_relu2   L__self___features_denseblock4_denselayer5_relu2            (l__self___features_denseblock4_denselayer5_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_92                                    L__self___features_denseblock4_denselayer5_conv2            (l__self___features_denseblock4_denselayer5_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_47                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition3_pool, new_features_84, new_features_86, new_features_88, new_features_90, new_features_92], 1)                                                                                                                                                                                                                                                                                                                                     {}
call_module    l__self___features_denseblock4_denselayer6_norm1   L__self___features_denseblock4_denselayer6_norm1            (concated_features_47,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock4_denselayer6_relu1   L__self___features_denseblock4_denselayer6_relu1            (l__self___features_denseblock4_denselayer6_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_94                               L__self___features_denseblock4_denselayer6_conv1            (l__self___features_denseblock4_denselayer6_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock4_denselayer6_norm2   L__self___features_denseblock4_denselayer6_norm2            (bottleneck_output_94,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock4_denselayer6_relu2   L__self___features_denseblock4_denselayer6_relu2            (l__self___features_denseblock4_denselayer6_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_94                                    L__self___features_denseblock4_denselayer6_conv2            (l__self___features_denseblock4_denselayer6_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_48                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition3_pool, new_features_84, new_features_86, new_features_88, new_features_90, new_features_92, new_features_94], 1)                                                                                                                                                                                                                                                                                                                    {}
call_module    l__self___features_denseblock4_denselayer7_norm1   L__self___features_denseblock4_denselayer7_norm1            (concated_features_48,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock4_denselayer7_relu1   L__self___features_denseblock4_denselayer7_relu1            (l__self___features_denseblock4_denselayer7_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_96                               L__self___features_denseblock4_denselayer7_conv1            (l__self___features_denseblock4_denselayer7_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock4_denselayer7_norm2   L__self___features_denseblock4_denselayer7_norm2            (bottleneck_output_96,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock4_denselayer7_relu2   L__self___features_denseblock4_denselayer7_relu2            (l__self___features_denseblock4_denselayer7_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_96                                    L__self___features_denseblock4_denselayer7_conv2            (l__self___features_denseblock4_denselayer7_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_49                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition3_pool, new_features_84, new_features_86, new_features_88, new_features_90, new_features_92, new_features_94, new_features_96], 1)                                                                                                                                                                                                                                                                                                   {}
call_module    l__self___features_denseblock4_denselayer8_norm1   L__self___features_denseblock4_denselayer8_norm1            (concated_features_49,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock4_denselayer8_relu1   L__self___features_denseblock4_denselayer8_relu1            (l__self___features_denseblock4_denselayer8_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_98                               L__self___features_denseblock4_denselayer8_conv1            (l__self___features_denseblock4_denselayer8_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock4_denselayer8_norm2   L__self___features_denseblock4_denselayer8_norm2            (bottleneck_output_98,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock4_denselayer8_relu2   L__self___features_denseblock4_denselayer8_relu2            (l__self___features_denseblock4_denselayer8_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_98                                    L__self___features_denseblock4_denselayer8_conv2            (l__self___features_denseblock4_denselayer8_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_50                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition3_pool, new_features_84, new_features_86, new_features_88, new_features_90, new_features_92, new_features_94, new_features_96, new_features_98], 1)                                                                                                                                                                                                                                                                                  {}
call_module    l__self___features_denseblock4_denselayer9_norm1   L__self___features_denseblock4_denselayer9_norm1            (concated_features_50,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock4_denselayer9_relu1   L__self___features_denseblock4_denselayer9_relu1            (l__self___features_denseblock4_denselayer9_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    bottleneck_output_100                              L__self___features_denseblock4_denselayer9_conv1            (l__self___features_denseblock4_denselayer9_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    l__self___features_denseblock4_denselayer9_norm2   L__self___features_denseblock4_denselayer9_norm2            (bottleneck_output_100,)                                                                                                                                                                                                                                                                                                                                                                                                                                            {}
call_module    l__self___features_denseblock4_denselayer9_relu2   L__self___features_denseblock4_denselayer9_relu2            (l__self___features_denseblock4_denselayer9_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_module    new_features_100                                   L__self___features_denseblock4_denselayer9_conv2            (l__self___features_denseblock4_denselayer9_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                 {}
call_function  concated_features_51                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition3_pool, new_features_84, new_features_86, new_features_88, new_features_90, new_features_92, new_features_94, new_features_96, new_features_98, new_features_100], 1)                                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock4_denselayer10_norm1  L__self___features_denseblock4_denselayer10_norm1           (concated_features_51,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock4_denselayer10_relu1  L__self___features_denseblock4_denselayer10_relu1           (l__self___features_denseblock4_denselayer10_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    bottleneck_output_102                              L__self___features_denseblock4_denselayer10_conv1           (l__self___features_denseblock4_denselayer10_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock4_denselayer10_norm2  L__self___features_denseblock4_denselayer10_norm2           (bottleneck_output_102,)                                                                                                                                                                                                                                                                                                                                                                                                                                            {}
call_module    l__self___features_denseblock4_denselayer10_relu2  L__self___features_denseblock4_denselayer10_relu2           (l__self___features_denseblock4_denselayer10_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    new_features_102                                   L__self___features_denseblock4_denselayer10_conv2           (l__self___features_denseblock4_denselayer10_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_function  concated_features_52                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition3_pool, new_features_84, new_features_86, new_features_88, new_features_90, new_features_92, new_features_94, new_features_96, new_features_98, new_features_100, new_features_102], 1)                                                                                                                                                                                                                                              {}
call_module    l__self___features_denseblock4_denselayer11_norm1  L__self___features_denseblock4_denselayer11_norm1           (concated_features_52,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock4_denselayer11_relu1  L__self___features_denseblock4_denselayer11_relu1           (l__self___features_denseblock4_denselayer11_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    bottleneck_output_104                              L__self___features_denseblock4_denselayer11_conv1           (l__self___features_denseblock4_denselayer11_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock4_denselayer11_norm2  L__self___features_denseblock4_denselayer11_norm2           (bottleneck_output_104,)                                                                                                                                                                                                                                                                                                                                                                                                                                            {}
call_module    l__self___features_denseblock4_denselayer11_relu2  L__self___features_denseblock4_denselayer11_relu2           (l__self___features_denseblock4_denselayer11_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    new_features_104                                   L__self___features_denseblock4_denselayer11_conv2           (l__self___features_denseblock4_denselayer11_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_function  concated_features_53                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition3_pool, new_features_84, new_features_86, new_features_88, new_features_90, new_features_92, new_features_94, new_features_96, new_features_98, new_features_100, new_features_102, new_features_104], 1)                                                                                                                                                                                                                            {}
call_module    l__self___features_denseblock4_denselayer12_norm1  L__self___features_denseblock4_denselayer12_norm1           (concated_features_53,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock4_denselayer12_relu1  L__self___features_denseblock4_denselayer12_relu1           (l__self___features_denseblock4_denselayer12_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    bottleneck_output_106                              L__self___features_denseblock4_denselayer12_conv1           (l__self___features_denseblock4_denselayer12_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock4_denselayer12_norm2  L__self___features_denseblock4_denselayer12_norm2           (bottleneck_output_106,)                                                                                                                                                                                                                                                                                                                                                                                                                                            {}
call_module    l__self___features_denseblock4_denselayer12_relu2  L__self___features_denseblock4_denselayer12_relu2           (l__self___features_denseblock4_denselayer12_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    new_features_106                                   L__self___features_denseblock4_denselayer12_conv2           (l__self___features_denseblock4_denselayer12_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_function  concated_features_54                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition3_pool, new_features_84, new_features_86, new_features_88, new_features_90, new_features_92, new_features_94, new_features_96, new_features_98, new_features_100, new_features_102, new_features_104, new_features_106], 1)                                                                                                                                                                                                          {}
call_module    l__self___features_denseblock4_denselayer13_norm1  L__self___features_denseblock4_denselayer13_norm1           (concated_features_54,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock4_denselayer13_relu1  L__self___features_denseblock4_denselayer13_relu1           (l__self___features_denseblock4_denselayer13_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    bottleneck_output_108                              L__self___features_denseblock4_denselayer13_conv1           (l__self___features_denseblock4_denselayer13_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock4_denselayer13_norm2  L__self___features_denseblock4_denselayer13_norm2           (bottleneck_output_108,)                                                                                                                                                                                                                                                                                                                                                                                                                                            {}
call_module    l__self___features_denseblock4_denselayer13_relu2  L__self___features_denseblock4_denselayer13_relu2           (l__self___features_denseblock4_denselayer13_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    new_features_108                                   L__self___features_denseblock4_denselayer13_conv2           (l__self___features_denseblock4_denselayer13_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_function  concated_features_55                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition3_pool, new_features_84, new_features_86, new_features_88, new_features_90, new_features_92, new_features_94, new_features_96, new_features_98, new_features_100, new_features_102, new_features_104, new_features_106, new_features_108], 1)                                                                                                                                                                                        {}
call_module    l__self___features_denseblock4_denselayer14_norm1  L__self___features_denseblock4_denselayer14_norm1           (concated_features_55,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock4_denselayer14_relu1  L__self___features_denseblock4_denselayer14_relu1           (l__self___features_denseblock4_denselayer14_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    bottleneck_output_110                              L__self___features_denseblock4_denselayer14_conv1           (l__self___features_denseblock4_denselayer14_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock4_denselayer14_norm2  L__self___features_denseblock4_denselayer14_norm2           (bottleneck_output_110,)                                                                                                                                                                                                                                                                                                                                                                                                                                            {}
call_module    l__self___features_denseblock4_denselayer14_relu2  L__self___features_denseblock4_denselayer14_relu2           (l__self___features_denseblock4_denselayer14_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    new_features_110                                   L__self___features_denseblock4_denselayer14_conv2           (l__self___features_denseblock4_denselayer14_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_function  concated_features_56                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition3_pool, new_features_84, new_features_86, new_features_88, new_features_90, new_features_92, new_features_94, new_features_96, new_features_98, new_features_100, new_features_102, new_features_104, new_features_106, new_features_108, new_features_110], 1)                                                                                                                                                                      {}
call_module    l__self___features_denseblock4_denselayer15_norm1  L__self___features_denseblock4_denselayer15_norm1           (concated_features_56,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock4_denselayer15_relu1  L__self___features_denseblock4_denselayer15_relu1           (l__self___features_denseblock4_denselayer15_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    bottleneck_output_112                              L__self___features_denseblock4_denselayer15_conv1           (l__self___features_denseblock4_denselayer15_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock4_denselayer15_norm2  L__self___features_denseblock4_denselayer15_norm2           (bottleneck_output_112,)                                                                                                                                                                                                                                                                                                                                                                                                                                            {}
call_module    l__self___features_denseblock4_denselayer15_relu2  L__self___features_denseblock4_denselayer15_relu2           (l__self___features_denseblock4_denselayer15_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    new_features_112                                   L__self___features_denseblock4_denselayer15_conv2           (l__self___features_denseblock4_denselayer15_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_function  concated_features_57                               <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition3_pool, new_features_84, new_features_86, new_features_88, new_features_90, new_features_92, new_features_94, new_features_96, new_features_98, new_features_100, new_features_102, new_features_104, new_features_106, new_features_108, new_features_110, new_features_112], 1)                                                                                                                                                    {}
call_module    l__self___features_denseblock4_denselayer16_norm1  L__self___features_denseblock4_denselayer16_norm1           (concated_features_57,)                                                                                                                                                                                                                                                                                                                                                                                                                                             {}
call_module    l__self___features_denseblock4_denselayer16_relu1  L__self___features_denseblock4_denselayer16_relu1           (l__self___features_denseblock4_denselayer16_norm1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    bottleneck_output_114                              L__self___features_denseblock4_denselayer16_conv1           (l__self___features_denseblock4_denselayer16_relu1,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    l__self___features_denseblock4_denselayer16_norm2  L__self___features_denseblock4_denselayer16_norm2           (bottleneck_output_114,)                                                                                                                                                                                                                                                                                                                                                                                                                                            {}
call_module    l__self___features_denseblock4_denselayer16_relu2  L__self___features_denseblock4_denselayer16_relu2           (l__self___features_denseblock4_denselayer16_norm2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_module    new_features_114                                   L__self___features_denseblock4_denselayer16_conv2           (l__self___features_denseblock4_denselayer16_relu2,)                                                                                                                                                                                                                                                                                                                                                                                                                {}
call_function  cat_61                                             <built-in method cat of type object at 0x7f84532e6760>      ([l__self___features_transition3_pool, new_features_84, new_features_86, new_features_88, new_features_90, new_features_92, new_features_94, new_features_96, new_features_98, new_features_100, new_features_102, new_features_104, new_features_106, new_features_108, new_features_110, new_features_112, new_features_114], 1)                                                                                                                                  {}
call_module    features                                           L__self___features_norm5                                    (cat_61,)                                                                                                                                                                                                                                                                                                                                                                                                                                                           {}
call_function  out                                                <function relu at 0x7f837b431510>                           (features,)                                                                                                                                                                                                                                                                                                                                                                                                                                                         {'inplace': True}
call_function  out_1                                              <function adaptive_avg_pool2d at 0x7f837b431000>            (out, (1, 1))                                                                                                                                                                                                                                                                                                                                                                                                                                                       {}
call_function  out_2                                              <built-in method flatten of type object at 0x7f84532e6760>  (out_1, 1)                                                                                                                                                                                                                                                                                                                                                                                                                                                          {}
call_module    out_3                                              L__self___classifier                                        (out_2,)                                                                                                                                                                                                                                                                                                                                                                                                                                                            {}
output         output                                             output                                                      ((out_3,),)                                                                                                                                                                                                                                                                                                                                                                                                                                                         {}

tensor([[ 0.0614, -0.4023, -0.2792,  ..., -0.5549,  0.0976, -0.0634],
        [-0.2032, -0.2706, -0.0935,  ..., -0.4815,  0.0758, -0.1038],
        [ 0.0637, -0.3492, -0.1492,  ..., -0.4841,  0.1776, -0.0723],
        ...,
        [-0.1050, -0.3393,  0.0092,  ..., -0.4862,  0.0555, -0.1058],
        [ 0.0018, -0.2431, -0.1656,  ..., -0.5072,  0.0977, -0.1387],
        [ 0.1192, -0.3563, -0.1147,  ..., -0.4839,  0.1770, -0.0659]],
       device='cuda:0', grad_fn=<AddmmBackward0>)

Using our custom backend, we can now see how TorchDynamo handles data-dependent control flow. Consider the function below, where the line if b.sum() < 0 is the source of data-dependent control flow: which branch runs depends on the values stored in b, not just on its shape or type.

def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

opt_bar = torch.compile(bar, backend=custom_backend)
inp1 = torch.randn(10)
inp2 = torch.randn(10)
opt_bar(inp1, inp2)
opt_bar(inp1, -inp2)
custom backend called with FX graph:
opcode         name    target                                                  args         kwargs
-------------  ------  ------------------------------------------------------  -----------  --------
placeholder    l_a_    L_a_                                                    ()           {}
placeholder    l_b_    L_b_                                                    ()           {}
call_function  abs_1   <built-in method abs of type object at 0x7f84532e6760>  (l_a_,)      {}
call_function  add     <built-in function add>                                 (abs_1, 1)   {}
call_function  x       <built-in function truediv>                             (l_a_, add)  {}
call_method    sum_1   sum                                                     (l_b_,)      {}
call_function  lt      <built-in function lt>                                  (sum_1, 0)   {}
output         output  output                                                  ((x, lt),)   {}
custom backend called with FX graph:
opcode         name    target                   args          kwargs
-------------  ------  -----------------------  ------------  --------
placeholder    l_x_    L_x_                     ()            {}
placeholder    l_b_    L_b_                     ()            {}
call_function  mul     <built-in function mul>  (l_x_, l_b_)  {}
output         output  output                   ((mul,),)     {}
custom backend called with FX graph:
opcode         name    target                   args         kwargs
-------------  ------  -----------------------  -----------  --------
placeholder    l_b_    L_b_                     ()           {}
placeholder    l_x_    L_x_                     ()           {}
call_function  b       <built-in function mul>  (l_b_, -1)   {}
call_function  mul_1   <built-in function mul>  (l_x_, b)    {}
output         output  output                   ((mul_1,),)  {}

tensor([-0.0176,  1.0753,  0.0282,  0.0756, -0.0176,  0.0633, -0.9161,  0.1333,
        -0.1971, -0.3406])

The output reveals that TorchDynamo extracted 3 different FX graphs corresponding to the following code (order may differ from the output above):

  1. x = a / (torch.abs(a) + 1), together with the comparison b.sum() < 0, whose result is returned to Python

  2. b = b * -1; return x * b

  3. return x * b

When TorchDynamo encounters unsupported Python features, such as data-dependent control flow, it breaks the computation graph, lets the default Python interpreter handle the unsupported code, then resumes capturing the graph.

Let’s step through bar to see how TorchDynamo handles it. If b.sum() < 0, TorchDynamo runs graph 1, lets Python evaluate the conditional, and then runs graph 2. Otherwise (b.sum() >= 0), TorchDynamo runs graph 1, lets Python evaluate the conditional, and then runs graph 3.

This highlights a major difference between TorchDynamo and previous PyTorch compiler solutions. When encountering unsupported Python features, previous solutions either raise an error or silently fail. TorchDynamo, on the other hand, will break the computation graph.
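To watch these graph breaks happen, one option is a backend that simply records each FX graph it is handed. The sketch below assumes a CPU-only run; counting_backend and captured_graphs are hypothetical names introduced for illustration, not part of PyTorch. Because TorchDynamo compiles the code after a graph break lazily, the graph for each branch should only be captured the first time that branch is taken.

```python
import torch
import torch._dynamo

# Hypothetical helper: every FX graph handed to the backend is stored here
captured_graphs = []

def counting_backend(gm: torch.fx.GraphModule, example_inputs):
    captured_graphs.append(gm)
    return gm.forward  # run the captured graph as-is, without optimizing it

def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

torch._dynamo.reset()  # start from an empty compilation cache
opt_bar = torch.compile(bar, backend=counting_backend)

a = torch.randn(10)
b = torch.ones(10)  # b.sum() >= 0

graph_counts = []
opt_bar(a, b)   # graph 1, then the "return x * b" graph
graph_counts.append(len(captured_graphs))
opt_bar(a, -b)  # graph 1 is reused; the "b = b * -1" branch is compiled now
graph_counts.append(len(captured_graphs))
opt_bar(a, b)   # both branches were already compiled, so nothing new is captured
graph_counts.append(len(captured_graphs))

print(graph_counts)
```

The first call should leave two graphs, the second call adds the third, and the third call is served entirely from cache.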

We can see where TorchDynamo breaks the graph by using torch._dynamo.explain:

# Reset since we are using a different backend.
torch._dynamo.reset()
explain_output = torch._dynamo.explain(bar)(torch.randn(10), torch.randn(10))
print(explain_output)
Graph Count: 2
Graph Break Count: 1
Op Count: 6
Break Reasons:
  Break Reason 1:
    Reason: generic_jump TensorVariable()
    User Stack:
      <FrameSummary file /var/lib/workspace/intermediate_source/torch_compile_tutorial.py, line 434 in bar>
Ops per Graph:
  Ops 1:
    <built-in method abs of type object at 0x7f84532e6760>
    <built-in function add>
    <built-in function truediv>
    <built-in function lt>
  Ops 2:
    <built-in function mul>
    <built-in function mul>
Out Guards:
  Guard 1:
    Name: ''
    Source: global
    Create Function: TORCH_FUNCTION_STATE
    Guard Types: None
    Code List: None
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 2:
    Name: "G['torch']"
    Source: global
    Create Function: FUNCTION_MATCH
    Guard Types: None
    Code List: None
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 3:
    Name: ''
    Source: global
    Create Function: DEFAULT_DEVICE
    Guard Types: ['DEFAULT_DEVICE']
    Code List: ['utils_device.CURRENT_DEVICE == None']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 4:
    Name: "G['torch'].abs"
    Source: global
    Create Function: FUNCTION_MATCH
    Guard Types: None
    Code List: None
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 5:
    Name: ''
    Source: shape_env
    Create Function: SHAPE_ENV
    Guard Types: None
    Code List: None
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 6:
    Name: "L['b']"
    Source: local
    Create Function: TENSOR_MATCH
    Guard Types: ['TENSOR_MATCH']
    Code List: ["hasattr(L['b'], '_dynamo_dynamic_indices') == False"]
    Object Weakref: <weakref at 0x7f82b2cd12b0; dead>
    Guarded Class Weakref: <weakref at 0x7f837bac7c90; to 'torch._C._TensorMeta' at 0x58585a0 (Tensor)>
  Guard 7:
    Name: ''
    Source: global
    Create Function: BACKEND_MATCH
    Guard Types: ['BACKEND_MATCH']
    Code List: ['___check_current_backend(140197382899168)']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 8:
    Name: ''
    Source: global
    Create Function: GRAD_MODE
    Guard Types: None
    Code List: None
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 9:
    Name: ''
    Source: global
    Create Function: DETERMINISTIC_ALGORITHMS
    Guard Types: None
    Code List: None
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 10:
    Name: "L['a']"
    Source: local
    Create Function: TENSOR_MATCH
    Guard Types: ['TENSOR_MATCH']
    Code List: ["hasattr(L['a'], '_dynamo_dynamic_indices') == False"]
    Object Weakref: <weakref at 0x7f82eb48d760; dead>
    Guarded Class Weakref: <weakref at 0x7f837bac7c90; to 'torch._C._TensorMeta' at 0x58585a0 (Tensor)>
  Guard 11:
    Name: ''
    Source: global
    Create Function: TORCH_FUNCTION_STATE
    Guard Types: None
    Code List: None
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 12:
    Name: ''
    Source: global
    Create Function: DEFAULT_DEVICE
    Guard Types: ['DEFAULT_DEVICE']
    Code List: ['utils_device.CURRENT_DEVICE == None']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 13:
    Name: ''
    Source: shape_env
    Create Function: SHAPE_ENV
    Guard Types: None
    Code List: None
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 14:
    Name: ''
    Source: global
    Create Function: BACKEND_MATCH
    Guard Types: ['BACKEND_MATCH']
    Code List: ['___check_current_backend(140197382899168)']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 15:
    Name: "L['b']"
    Source: local
    Create Function: TENSOR_MATCH
    Guard Types: ['TENSOR_MATCH']
    Code List: ["hasattr(L['b'], '_dynamo_dynamic_indices') == False"]
    Object Weakref: <weakref at 0x7f82b2cd12b0; dead>
    Guarded Class Weakref: <weakref at 0x7f837bac7c90; to 'torch._C._TensorMeta' at 0x58585a0 (Tensor)>
  Guard 16:
    Name: ''
    Source: global
    Create Function: GRAD_MODE
    Guard Types: None
    Code List: None
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 17:
    Name: "L['x']"
    Source: local
    Create Function: TENSOR_MATCH
    Guard Types: ['TENSOR_MATCH']
    Code List: ["hasattr(L['x'], '_dynamo_dynamic_indices') == False"]
    Object Weakref: <weakref at 0x7f82b2f727a0; dead>
    Guarded Class Weakref: <weakref at 0x7f837bac7c90; to 'torch._C._TensorMeta' at 0x58585a0 (Tensor)>
  Guard 18:
    Name: ''
    Source: global
    Create Function: DETERMINISTIC_ALGORITHMS
    Guard Types: None
    Code List: None
    Object Weakref: None
    Guarded Class Weakref: None
Compile Times: TorchDynamo compilation metrics:
Function                         Runtimes (s)
-------------------------------  --------------
_compile.<locals>.compile_inner  0.0110, 0.0066
OutputGraph.call_user_compiler   0.0001, 0.0000

In order to maximize speedup, graph breaks should be limited. We can force TorchDynamo to raise an error upon the first graph break encountered by using fullgraph=True:

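import traceback as tb

opt_bar = torch.compile(bar, fullgraph=True)
try:
    opt_bar(torch.randn(10), torch.randn(10))
except Exception:
    # fullgraph=True turns the graph break into an error; print its traceback
    tb.print_exc()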
Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 482, in <module>
    opt_bar(torch.randn(10), torch.randn(10))
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 921, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
    return _compile(
  File "/opt/conda/envs/py_3.10/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 676, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 535, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 500, in transform
    tracer.run()
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2149, in run
    super().run()
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 810, in run
    and self.step()
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 773, in step
    getattr(self, inst.opname)(inst)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 464, in inner
    raise exc.UserError(
torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands

from user code:
   File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 434, in bar
    if b.sum() < 0:

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
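When fullgraph=True rejects a function like bar, one way to remove the data-dependent branch is to compute both outcomes as tensor operations and let torch.where select between them. The error message above points to cond instead; torch.where is a swapped-in, simpler alternative that works when both branches are cheap to compute. In the sketch below, bar_where and counting_backend are hypothetical names; the recording backend runs each graph unoptimized, so the example needs neither a GPU nor a C++ toolchain.

```python
import torch
import torch._dynamo

def bar(a, b):
    # Original function, used here only as an eager reference
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

def bar_where(a, b):
    # Hypothetical branch-free rewrite: torch.where picks -b or b inside the
    # graph, so Python never has to inspect a tensor value
    x = a / (torch.abs(a) + 1)
    b = torch.where(b.sum() < 0, -b, b)
    return x * b

captured_graphs = []

def counting_backend(gm: torch.fx.GraphModule, example_inputs):
    # Hypothetical backend: record each FX graph and run it unoptimized
    captured_graphs.append(gm)
    return gm.forward

torch._dynamo.reset()
opt_bar_where = torch.compile(bar_where, backend=counting_backend, fullgraph=True)

a = torch.randn(10)
b = torch.ones(10)
# Exercise both outcomes: b.sum() >= 0 and b.sum() < 0
matches = [
    torch.allclose(opt_bar_where(a, b), bar(a, b)),
    torch.allclose(opt_bar_where(a, -b), bar(a, -b)),
]
print(len(captured_graphs), matches)
```

Both calls should reuse a single captured graph, and fullgraph=True no longer raises.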

Below, we show that TorchDynamo does not break the graph on the model we used earlier to demonstrate speedups: compiling it with fullgraph=True raises no error.

opt_model = torch.compile(init_model(), fullgraph=True)
print(opt_model(generate_data(16)[0]))
tensor([[ 0.1344,  0.1976, -0.2056,  ...,  0.1861, -0.1632, -0.1642],
        [ 0.2680,  0.2349, -0.1903,  ...,  0.2170, -0.0055,  0.0309],
        [ 0.0164,  0.2180, -0.1115,  ...,  0.1709, -0.1681, -0.0635],
        ...,
        [ 0.0807,  0.2680, -0.1889,  ...,  0.0381, -0.2073, -0.1444],
        [-0.0209,  0.0858, -0.2457,  ...,  0.1862, -0.1280, -0.0282],
        [-0.0386,  0.0730, -0.1961,  ...,  0.0864, -0.2199, -0.1485]],
       device='cuda:0', grad_fn=<CompiledFunctionBackward>)

We can use torch.export (from PyTorch 2.1+) to extract a single, exportable FX graph from the input PyTorch program. The exported graph is intended to run in other environments, including ones without Python. One important restriction is that torch.export does not support graph breaks. Please check this tutorial for more details on torch.export.

Conclusion

In this tutorial, we introduced torch.compile by covering basic usage, demonstrating speedups over eager mode, comparing it to previous PyTorch compiler solutions, and briefly investigating TorchDynamo and its interactions with FX graphs. We hope that you will give torch.compile a try!

Total running time of the script: (6 minutes 10.264 seconds)
