Why do backward and optimizer.step become slower with jit.trace?

Hi all! I am trying to use the JIT to speed up my training. However, when profiling my code with JIT trace, I found that while zero_grad and forward are much faster than without JIT, backward and optimizer.step become slower. Why does this happen? I am also wondering whether the JIT compiler applies any optimizations to backward and optimizer.step. I searched the Internet but found nothing about this.

Profiling result:

Code to reproduce this:

import torch
import torch.nn as nn

def profile(func):
    from line_profiler import LineProfiler

    def wrapper(*args, **kwargs):
        lp = LineProfiler()
        lp_wrapper = lp(func)
        result = lp_wrapper(*args, **kwargs)
        lp.print_stats()

        return result
    return wrapper


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()

        self.layers = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 32, 3, 1, 1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.AvgPool2d(2, 2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32*64*64, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 8),
        )
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x, target):
        x = self.layers(x)
        return x, self.loss_fn(x, target)


@profile
def perf_train(model, inp, target):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    import time
    for _ in range(50): # warmup
        optimizer.zero_grad()
        out, loss = model(inp, target)
        loss.backward()
        optimizer.step()
    start = time.time_ns()
    for _ in range(500):
        optimizer.zero_grad()
        out, loss = model(inp, target)
        loss.backward()
        optimizer.step()
    print(f'dur: {(time.time_ns() - start) / 1e6} ms')
    return out


def eager(perf_func):
    device = torch.device('cuda')

    mod = MyModule().to(device)
    inp = torch.randn(128, 3, 128, 128).to(device)
    target = torch.randint(0, 8, (128,)).to(inp.device)

    out = perf_func(mod, inp, target)
    
    # print(out.shape)


def trace(perf_func):
    device = torch.device('cuda')

    mod = MyModule().to(device)
    inp = torch.randn(128, 3, 128, 128).to(device)
    target = torch.randint(0, 8, (128,)).to(inp.device)

    traced = torch.jit.trace(mod, (inp, target))

    out = perf_func(traced, inp, target)
    
    # print(out.shape)

# run only one of these at a time
# eager(perf_train)
trace(perf_train)

It seems you are using the GPU for your model training and LineProfiler to profile the code. I doubt this profiler is aware of CUDA's asynchronous execution, so the reported per-line results might be invalid. Did you verify them against manual profiling via synchronized timers or via event-based profiling?
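
In case it helps, here is a minimal sketch of what I mean by event-based timing; the timed_step helper is just for illustration and not part of your code:

import torch

def timed_step(model, optimizer, inp, target):
    # CUDA events are recorded on the GPU stream itself, so they account
    # for asynchronously launched kernels that a host-side timer misses.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    torch.cuda.synchronize()  # drain pending work from earlier iterations
    start.record()
    optimizer.zero_grad()
    out, loss = model(inp, target)
    loss.backward()
    optimizer.step()
    end.record()
    torch.cuda.synchronize()  # wait until both events have completed
    return out, start.elapsed_time(end)  # elapsed GPU time in milliseconds

If you record separate event pairs around zero_grad, forward, backward, and step, you should see where the time is actually spent. A host-side line profiler only "sees" time where the CPU happens to block and wait for the GPU, which is often inside backward or optimizer.step, so those lines can look slower than they really are. Alternatively, torch.profiler with CUDA activity enabled reports per-op GPU times directly, e.g.:

from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    optimizer.zero_grad()
    out, loss = model(inp, target)
    loss.backward()
    optimizer.step()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))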