Performance of `torch.compile` is significantly slowed down under `torch.inference_mode`

I have observed that when using torch.compile to optimize a model, performance degrades significantly when running inference under torch.inference_mode; it is even worse than the non-optimized model. However, when I also place the compilation (warm-up) step inside the torch.inference_mode context, the performance issue goes away.

My results, tested with torch==2.1.0+cu118 on an A10 GPU:

# 1. both are disabled, speed up 2.5X, it's great!
enable inference mode when compile: False
enable inference mode when benchmark: False
not compiled latency: 0.002047449684143066
compiled latency: 0.0008150367736816407, speed up: 2.5120948529652662

# 2. both are enabled, speed up 2.35X, great work!
enable inference mode when compile: True
enable inference mode when benchmark: True
not compiled latency: 0.0019193056106567383
compiled latency: 0.0008156064033508301, speed up: 2.3532252846121366

# 3. inference mode enabled when benchmarking, disabled when compiling
# the compiled model is slower than the naive torch model! Help!
# and the naive model's latency is still ~0.002
enable inference mode when compile: False
enable inference mode when benchmark: True
not compiled latency: 0.0019556352615356445
compiled latency: 0.010692658996582031, speed up: 0.18289513040309005

# 4. inference mode disabled when benchmarking, enabled when compiling; the result is also disappointing
enable inference mode when compile: True
enable inference mode when benchmark: False
not compiled latency: 0.0020351423263549806
compiled latency: 0.011507730865478516, speed up: 0.17685001067066186

And here is the code to reproduce my results:

import torch
from torchvision.models import resnet18
import sys
from contextlib import nullcontext

# enable inference mode when compile or benchmark?
enable_when_compile = sys.argv[1] == "true"
enable_when_benchmark = sys.argv[2] == "true"


def _benchmark(
    iters,
    f,
    context,   # torch.inference_mode or nullcontext
    *args,
    **kwargs,
) -> float:
    """Estimates the average time for a single inference call, in seconds.
    Returns:
        estimated average time in seconds for a single inference call
    """
    with context():
        f(*args, **kwargs)
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    with context():
        start_event.record()
        for _ in range(iters):
            f(*args, **kwargs)
        end_event.record()
    torch.cuda.synchronize()
    elapsed_time_s = start_event.elapsed_time(end_event) * 1.0e-3
    avg_time_s = elapsed_time_s / iters
    print("Estimated average time duration: {:.6f} s".format(avg_time_s))
    return avg_time_s


class BenchmarkRunner(object):
    def __init__(self, use_inference_mode: bool):
        self.context = nullcontext if not use_inference_mode else torch.inference_mode

    def __call__(self, iters, f, *args, **kwargs) -> float:
        return _benchmark(iters, f, self.context, *args, **kwargs)


@torch.no_grad()
def run():
    input = [torch.rand(8, 3, 224, 224).to(torch.device("cuda"), dtype=torch.float16)]
    net = resnet18(pretrained=False).cuda().half()
    net.eval()
    context = nullcontext if not enable_when_compile else torch.inference_mode
    compiled = torch.compile(net, mode="reduce-overhead", backend="inductor")
    with context():
        _ = compiled(*input)

    latency_compiled = BenchmarkRunner(enable_when_benchmark)(10, compiled, *input)
    latency_torch = BenchmarkRunner(enable_when_benchmark)(10, net, *input)

    print(f"enable inference mode when compile: {enable_when_compile}")
    print(f"enable inference mode when benchmark: {enable_when_benchmark}")
    print(f"not compiled latency: {latency_torch}")
    print(f"compiled latency: {latency_compiled}, speed up: {latency_torch / latency_compiled}")

if __name__ == "__main__":
    run()

This seems like a bug - @Crazyai can you file a GitHub issue with your repro instructions?

As a workaround - you can replace torch.inference_mode with torch.no_grad. For some context:

(1) I tried your repro with torch.no_grad() and I see the expected speedups with torch.compile:

enable inference mode when compile: True
enable inference mode when benchmark: False
not compiled latency: 0.013454794311523438
compiled latency: 0.0005569119930267334, speed up: 24.159641882371112
...
enable inference mode when compile: False
enable inference mode when benchmark: True
not compiled latency: 0.012717078399658202
compiled latency: 0.0005573919773101807, speed up: 22.81532371712148
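
For concreteness, a minimal sketch of that substitution (a sketch only, assuming the net and input from the repro script above; only the context manager changes):

# minimal sketch of the workaround: use torch.no_grad instead of torch.inference_mode
# (assumes `net` and `input` as defined in the repro script above)
compiled = torch.compile(net, mode="reduce-overhead", backend="inductor")
with torch.no_grad():         # instead of torch.inference_mode()
    _ = compiled(*input)      # warm-up / compilation
    for _ in range(10):
        _ = compiled(*input)  # timed region in the actual benchmark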

(2) It’s worth calling out the difference between inference_mode and no_grad, which is a bit subtle.

tldr: inference_mode has the additional benefit of removing some cpu overhead in eager mode. When using torch.compile, the difference between inference_mode and no_grad is mostly negligible - torch.compile itself already compiles the model to avoid cpu overhead as much as possible.

In the end state, with torch.compile, no_grad and inference_mode should probably be equivalent (both in terms of what they do and what their perf implications are). But if there are any bugs with torch.compile and inference_mode, we should fix them.

In eager mode, no_grad can be used to disable gradient calculation (and, for example, avoid saving any activations for the backward if you only care about inference). However, even when you run under no_grad, there is still some autograd-related overhead in eager mode: autograd performs view tracking and version_counter tracking, which has some cpu-side overhead associated with it.
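
As a small illustration of that difference (a sketch, not from the original benchmark): tensors produced under inference_mode are marked as inference tensors and skip version-counter and view tracking, while tensors produced under no_grad are still ordinary tensors:

import torch

x = torch.ones(3)

with torch.no_grad():
    y = x * 2
print(y.is_inference())   # False: y still carries a version counter and view metadata

with torch.inference_mode():
    z = x * 2
print(z.is_inference())   # True: z is an inference tensor with no such tracking

# Using an inference tensor in autograd-recording code later is disallowed,
# e.g. z.requires_grad_(True) raises a RuntimeError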

When you torch.compile a model, the resulting generated code that we execute at runtime is (90% of the time) generated cpp code or generated triton code. None of this code is seen by autograd, so there is no view tracking or version_counter tracking that needs to be disabled. (The small caveat is that inductor can sometimes generate calls to pytorch operations, like aten.as_strided and aten.addmm, which will have a tiny bit of autograd overhead that can go away if you’re using inference mode. This is much smaller, though, than the amount of overhead you’ll see in eager mode, and is probably not noticeable.)

@bdhirsh Thanks for the reply. I just posted an issue on GitHub; see it here. Glad to contribute to PyTorch.

@bdhirsh I am curious why the speedup ratio in your run is more than 20X… What hardware did you test on?

And when I use torch.no_grad instead of torch.inference_mode, the results are still confusing:

enable inference mode when compile: True
enable inference mode when benchmark: False
not compiled latency: 0.0024103456497192384
compiled latency: 0.011797164916992188, speed up: 0.20431566962732445

(The only diff from the script above is using torch.no_grad in place of torch.inference_mode.)

Strange, I’m benchmarking on an A100 using a recent nightly.

(1) Can you try running your benchmark (with torch.no_grad) on a recent nightly to see if you still don’t see speedups?

(2) Most of the speedup here is probably coming from torch.compile using cudagraphs to remove most of the python overhead. If you use a nightly and run with mode="reduce-overhead", torch.compile should give a warning if it tried and failed to use cudagraphs, and also explain why.

(3) You can also try running with the pytorch profiler to see where the slowdown is coming from (if you’re seeing a 5x slowdown then the profiler will probably point to something pretty obvious if you can get a profile): torch.profiler — PyTorch 2.1 documentation
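
For reference, a minimal profiling sketch (assuming the compiled model and input from the repro above):

import torch
from torch.profiler import profile, ProfilerActivity

with torch.inference_mode():
    compiled(*input)  # warm up; compilation happens on the first call
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        for _ in range(10):
            compiled(*input)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))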

@bdhirsh A quick test after updating to torch 2.1.1 + cu121: the speedup with torch.no_grad looks fine now, but the slowdown with torch.inference_mode is not fixed.