I have a model where, when I run only the forward pass, it is slower with set_grad_enabled(False) than with set_grad_enabled(True). My intuition is that set_grad_enabled(False) should only stop saving activations and intermediate tensors for the backward pass, but otherwise run the code exactly the same.
However, I’ve looked at the trace in NVIDIA’s Nsight Systems and I do see that different kernels are executed as a result of changing this flag (nothing else in my model uses this parameter to change code paths). Specifically, my linear layers take about 10x longer when gradients are disabled.
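For anyone wanting to poke at this without Nsight, the same kernel names should also show up in a torch.profiler trace. A minimal sketch with a stand-in model (just a couple of linear layers, not my real model):

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Stand-in model and input; my real model is larger, these are only placeholders.
model = torch.nn.Sequential(
    torch.nn.Linear(2048, 2048),
    torch.nn.GELU(),
    torch.nn.Linear(2048, 1024),
).half().cuda()
x = torch.randn(64, 2048, dtype=torch.float16, device="cuda")

for grad_enabled in (True, False):
    with torch.set_grad_enabled(grad_enabled):
        for _ in range(3):  # warm-up so one-time setup isn't profiled
            model(x)
        torch.cuda.synchronize()
        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
            model(x)
            torch.cuda.synchronize()
    print(f"\ngrad_enabled={grad_enabled}")
    # The GEMM kernel names (ampere_fp16_s16816gemm_...) appear in this table.
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```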
When I enable gradients and check the kernel chosen for a particular linear layer, I see
ampere_fp16_s16816gemm_fp16_256x128_ldg8_f2f_stages_64x3_tn
block: <<<128, 1, 1>>>
grid: <<<17, 263, 2>>>
In the no-grad case, looking at the same linear layer, I see the kernel
ampere_fp16_s16816gemm_fp16_64x64_sliced1x2_ldg8_f2f_stages_64x5_tn
block: <<<128, 1, 1>>>
grid: <<<33, 1, 33600>>>
Firstly, I don’t understand why the set_grad_enabled context manager has any effect on which kernel gets launched. Secondly, the kernel chosen in the no-grad case seems to launch far too many blocks, and that’s probably why it’s so much slower.
Some other info: this is on an NVIDIA A100 with torch==2.1.0+cu118, and I’m not using torch.compile. I’ve tried to reproduce this with only a single linear layer, but I don’t observe the behaviour in that case. The tensors being multiplied are the (1024, 2048) linear weights and the (2048, 1, 33600) activations.
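A minimal sketch of the kind of single-layer comparison I mean (the activation layout is guessed from the shapes above, so it may not match my model exactly):

```python
import torch

torch.manual_seed(0)

# Shapes guessed from the description: (1024, 2048) weight, 33600 batches of (1, 2048) activations.
linear = torch.nn.Linear(2048, 1024, bias=False, dtype=torch.float16, device="cuda")
x = torch.randn(33600, 1, 2048, dtype=torch.float16, device="cuda")

def time_forward(grad_enabled: bool, iters: int = 50) -> float:
    """Average forward time in ms under the given grad setting."""
    with torch.set_grad_enabled(grad_enabled):
        for _ in range(10):  # warm-up
            linear(x)
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            linear(x)
        end.record()
        torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

print(f"grad enabled : {time_forward(True):.3f} ms")
print(f"grad disabled: {time_forward(False):.3f} ms")
```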