Model forward slows down when disabling gradients

I have a model where running only the forward pass is slower with set_grad_enabled(False) than with set_grad_enabled(True). My intuition was that set_grad_enabled(False) should only stop saving the activations and intermediate tensors needed for the backward pass, but otherwise run exactly the same code.

However, I’ve looked at the trace in NVIDIA’s Nsight Systems and I do see that different kernels get executed as a result of changing this flag (nothing else in my model uses this flag to change code paths). Specifically, my linear layers take about 10x longer when gradients are disabled.

When I enable gradients and check the kernel chosen for a particular linear layer, I see
ampere_fp16_s16816gemm_fp16_256x128_ldg8_f2f_stages_64x3_tn
block: <<<128, 1, 1>>>
grid: <<<17, 263, 2>>>

In the no grad case, looking at the same linear layer, I see the kernel
ampere_fp16_s16816gemm_fp16_64x64_sliced1x2_ldg8_f2f_stages_64x5_tn
block: <<<128, 1, 1>>>
grid: <<<33, 1, 33600>>>

Firstly, I don’t understand why the set_grad_enabled context manager has any effect on which kernel gets launched. Secondly, it seems the kernel chosen in the “no grad” case launches far too many blocks, and that’s probably why it’s so much slower.

Some other info: this is on an NVIDIA A100 with torch==2.1.0+cu118, and I’m not using torch.compile. I’ve tried to reproduce this with only a single linear layer, but I don’t observe the behaviour in that case. The tensors being multiplied are the (1024, 2048) linear weight and the (2048, 1, 33600) activations.

Could you post a minimal and executable code snippet reproducing the different kernel execution?

I’m working on this, but so far I’m only able to repro it with my entire model, which is complicated and spread across a bunch of files. In the meantime, is there any insight into what set_grad_enabled is actually doing? It’s doing a lot more than I thought if it’s actually changing which kernels get launched.
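For reference, the kind of standalone attempt I’ve been trying looks roughly like this (the layer and input shapes are approximations of one layer in my model; this does not show the kernel difference):

import torch

# rough approximation of one of my linear layers; the real model is much bigger
lin = torch.nn.Linear(2048, 1024, device="cuda", dtype=torch.float16)
x = torch.randn(33600, 1, 2048, device="cuda", dtype=torch.float16)

for grad_enabled in (True, False):
    with torch.set_grad_enabled(grad_enabled):
        torch.cuda.synchronize()
        y = lin(x)  # profiled with Nsight Systems in each mode
        torch.cuda.synchronize()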

set_grad_enabled(False) will disable creating the computation graph and will thus not store the intermediate activations needed for the backward pass. It should not change any kernel selection, which is why we would need a reproducer to debug this further.
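As a quick sanity check of what grad mode is supposed to change, the outputs should just lose their autograd metadata (a small sketch):

import torch

lin = torch.nn.Linear(8, 4)
x = torch.randn(2, 8)

with torch.set_grad_enabled(True):
    y = lin(x)
    print(y.requires_grad, y.grad_fn)  # True, a backward node is attached

with torch.set_grad_enabled(False):
    y = lin(x)
    print(y.requires_grad, y.grad_fn)  # False, None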

It is possible that grad mode changes the code path for matmuls; see e10cfdd8953 regresses hf_Longformer-cuda-eager on pytorch 2024-01-16 · Issue #118548 · pytorch/pytorch · GitHub

To double check, you can run the following to see what ops are running underneath your linear:

from torch.testing._internal.logging_tensor import capture_logs_with_logging_tensor_mode

with capture_logs_with_logging_tensor_mode() as logs:
    out = model(inputs)  # <do stuff here>: run the forward pass you want to inspect

print('\n'.join(logs))

Thanks @soulitzer! This seems like exactly what’s happening, actually.

In the case when I enable gradients, I get

...
$507: f16[33600, 1, 2056] = torch._ops.aten.constant_pad_nd.default($320, ['0', '0', '0', '0', '0', '0'], 0.0)
$508: f16[2100, 1, 2056] = torch._ops.aten.constant_pad_nd.default($506, ['0', '0', '0', '0', '0', '0'], 0.0)
$510: f16[4112, 2056] = torch._ops.aten._to_copy.default($509, dtype=torch.float16)
$511: f16[2056, 4112] = torch._ops.aten.t.default($510)
$512: f16[2100, 2056] = torch._ops.aten.view.default($508, ['2100', '2056'])
$513: f16[2100, 4112] = torch._ops.aten.mm.default($512, $511)
...

So it’s doing a single matmul of (2100, 2056) with (2056, 4112).

When I disable gradients though, I get

...
$322: f16[33600, 1, 2056] = torch._ops.aten.constant_pad_nd.default($319, ['0', '0', '0', '0', '0', '0'], 0.0)
$323: f16[2100, 1, 2056] = torch._ops.aten.constant_pad_nd.default($321, ['0', '0', '0', '0', '0', '0'], 0.0)
$325: f16[4112, 2056] = torch._ops.aten._to_copy.default($324, dtype=torch.float16)
$326: f16[2056, 4112] = torch._ops.aten.t.default($325)
$327: f16[2100, 1, 2056] = torch._ops.aten.expand.default($323, ['2100', '1', '2056'])
$328: f16[2100, 1, 2056] = torch._ops.aten.view.default($327, ['2100', '1', '2056'])
$329: f16[2100, 2056, 4112] = torch._ops.aten.expand.default($326, ['2100', '2056', '4112'])
$330: f16[2100, 2056, 4112] = torch._ops.aten.view.default($329, ['2100', '2056', '4112'])
$331: f16[2100, 1, 4112] = torch._ops.aten.bmm.default($328, $330)
...

This is the same linear layer, and I notice it inserts an expand so that it ends up doing a batched matmul of (2100, 1, 2056) with (2100, 2056, 4112). The end result is 2100 separate vector-matrix multiplies of (1, 2056) with (2056, 4112), which is less efficient than the single (2100, 2056) @ (2056, 4112) matmul in the first case.
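For reference, the two shapes from the logs can be compared in isolation with something like this (a rough standalone sketch, not my model code):

import torch

a = torch.randn(2100, 2056, device="cuda", dtype=torch.float16)
w = torch.randn(2056, 4112, device="cuda", dtype=torch.float16)

def timed(fn, iters=50):
    # simple CUDA-event timing with a single warmup call
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    fn()  # warmup
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# grad-enabled path from the logs: one (2100, 2056) @ (2056, 4112) mm
t_mm = timed(lambda: torch.mm(a, w))

# no-grad path from the logs: (2100, 1, 2056) bmm (2100, 2056, 4112)
a3 = a.unsqueeze(1)                       # (2100, 1, 2056)
w3 = w.unsqueeze(0).expand(2100, -1, -1)  # (2100, 2056, 4112), expanded view, no copy
t_bmm = timed(lambda: torch.bmm(a3, w3))

print(f"mm: {t_mm:.3f} ms/iter, bmm: {t_bmm:.3f} ms/iter")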

Based on the linked comment, it seems like the second snippet is hitting case 3 and the first snippet is hitting case 2? Weirdly, I can’t reproduce the issue with just a linear layer, even when I try to ensure that the input tensor isn’t contiguous (roughly what I’ve been trying is below).
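My single-layer check looks something like this (sizes loosely based on the logs above, simplified from my actual model), and I don’t see the bmm showing up:

import torch
from torch.testing._internal.logging_tensor import capture_logs_with_logging_tensor_mode

# sizes loosely based on the logs above; simplified from my actual model
lin = torch.nn.Linear(2056, 4112, device="cuda", dtype=torch.float16)

# build a non-contiguous (2100, 1, 2056) input by slicing a larger tensor
x = torch.randn(2100, 2056, 2, device="cuda", dtype=torch.float16)[..., 0].unsqueeze(1)

with torch.set_grad_enabled(False), capture_logs_with_logging_tensor_mode() as logs:
    lin(x)

# check which matmul variant (aten.mm / aten.addmm / aten.bmm) actually ran
print("\n".join(logs))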

This ends up happening in a bunch of different places in the model, so it slows down validation considerably. And sometimes that 2100 dimension is much, much bigger, making it even more inefficient.