Torch compile: optimizer.step Generates Excessive Warning Messages

I’m experiencing an issue with excessive warning messages when using torch.compile() on the optimizer step function.

While compiling the optimizer step provides a significant speed improvement during training (particularly with AdamW), I’m encountering challenges when using it with a learning rate scheduler.

Following the approach outlined in this discussion and the official PyTorch tutorial, I’ve wrapped the learning rate in a tensor to prevent recompilations. However, this generates thousands of warning messages, such as:

('Grad tensors ["L['self'].param_groups[394]['params'][0].grad"] will be copied during cudagraphs execution. If using cudagraphs and the grad tensor addresses will be the same across runs, use torch._dynamo.decorators.mark_static_address to elide this copy.',)

These warnings appear even when running the simple linear network example from the official tutorial.
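For reference, the setup looks roughly like this (a minimal sketch in the spirit of the tutorial’s linear-network example; the layer sizes and scheduler choice here are just placeholders):

import torch

# Small stack of linear layers, similar to the tutorial example.
model = torch.nn.Sequential(
    *[torch.nn.Linear(1024, 1024, device="cuda") for _ in range(5)]
)
inputs = torch.rand(1024, device="cuda")
model(inputs).sum().backward()  # populate .grad once so the optimizer has something to step on

# Wrapping the learning rate in a tensor is what prevents recompilations
# when the scheduler changes it between steps (tensor lr requires a recent PyTorch).
opt = torch.optim.AdamW(model.parameters(), lr=torch.tensor(0.01))
sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)

@torch.compile(fullgraph=False)
def step():
    opt.step()
    sched.step()

# The cudagraphs warnings above are printed while these compiled steps run.
for _ in range(5):
    step()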

Is there a recommended way to:

  1. Prevent these warnings from occurring?
  2. Suppress these specific warning messages?
  3. Handle this situation differently?

Any guidance would be greatly appreciated.

This looks CUDA graphs related, cc @mlazos and @Elias_Ellison.


@hassonofer can you share your code?

This warning can be suppressed by calling torch._dynamo.decorators.mark_static_address on each of the grads of the parameters. It’s a warning because, when running the optimizer standalone with cudagraphs, you can get very bad perf without doing this, due to host-device copies of the gradients for every parameter, which can be quite large.


Thanks for your response.

I don’t have a minimal code example to share. It happens on my Birder project.
I’ll try to illustrate.
I’m training a Swin Transformer v2 (small), using torch.optim.AdamW and a torch.optim.lr_scheduler.CosineAnnealingLR scheduler.

The model is being compiled without any special arguments:
net = torch.compile(net)

The optimizer step function is compiled without fullgraph (as documented):
optimizer.step = torch.compile(optimizer.step, fullgraph=False)

Note: the current main branch does not convert the learning rate to a torch.tensor, but the logging happens either way. The optimizer is defined here.

I think the only thing that is not “standard” in this training code is the way the parameter groups are built. It’s a bit more complex in order to support features like layer decay, custom weight decay per layer type, etc.
The function is defined here, and a specific key can be seen here.
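Roughly, the grouping logic looks something like this (a simplified sketch, not the actual Birder function; the decay values and group keys are just placeholders):

import torch

def make_param_groups(model, base_lr, weight_decay, layer_decay=0.9):
    # Sketch of per-layer parameter groups: earlier blocks get a smaller lr
    # multiplier (layer decay), and 1-D tensors (biases, norms) skip weight decay.
    blocks = list(model.children())
    groups = []
    for depth, block in enumerate(blocks):
        lr_scale = layer_decay ** (len(blocks) - 1 - depth)
        for _name, p in block.named_parameters():
            if not p.requires_grad:
                continue
            wd = 0.0 if p.ndim <= 1 else weight_decay
            groups.append({"params": [p], "lr": base_lr * lr_scale, "weight_decay": wd})
    return groups

optimizer = torch.optim.AdamW(make_param_groups(net, base_lr=1e-3, weight_decay=0.05))

A near-per-parameter grouping like this ends up with hundreds of small groups, which would be consistent with the param_groups[394] index in the warning above.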

I assume going through all of this is too much of a hassle :sweat_smile:, sorry about that…

I’ll try to create a straightforward code example to reproduce it tomorrow (it’s kinda late here now).

Thanks a lot.

Alternatively you can just call torch._dynamo.decorators.mark_static_address on each of the .grad attributes of the parameters. This is the recommended way to remove the warning.

Does this mean something like this?

from torch._dynamo.decorators import mark_static_address

for p in model.parameters():
    mark_static_address(p)

Is it safe? Do model parameters generally have a static address?

Thanks again

You should mark p.grad as static. You can verify that this is safe and that the grads have the same address by printing p.grad.data_ptr() on each iteration. This is generally safe across iterations.
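Something along these lines (a sketch; net, criterion, loader, and optimizer are placeholders for your own training objects):

# Check that gradient tensor addresses stay the same across iterations
# before marking them static.
prev_ptrs = None
for inputs, targets in loader:
    loss = criterion(net(inputs), targets)
    loss.backward()
    ptrs = [p.grad.data_ptr() for p in net.parameters() if p.grad is not None]
    if prev_ptrs is not None:
        assert ptrs == prev_ptrs, "grad addresses changed between iterations"
    prev_ptrs = ptrs
    optimizer.step()
    # set_to_none=False keeps the same grad tensors alive (and the same addresses);
    # the default set_to_none=True frees them, so the addresses could change.
    optimizer.zero_grad(set_to_none=False)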


@mlazos this solution is quite awkward. Let’s imagine we want to compile our step function. We cannot mark our parameters’ gradients as static beforehand (they don’t exist yet), and we cannot place mark_static_address into our step function since it cannot be traced.
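The closest thing to a workaround I can see is an explicit warm-up iteration so the grads exist before marking them, which is exactly the kind of boilerplate I’d like to avoid (a sketch only; criterion, sample_inputs, and sample_targets are placeholders):

import torch
from torch._dynamo.decorators import mark_static_address

# Warm-up pass so that every parameter's .grad tensor actually exists.
loss = criterion(model(sample_inputs), sample_targets)
loss.backward()
for p in model.parameters():
    if p.grad is not None:
        mark_static_address(p.grad)

# Only now compile and run the step.
optimizer.step = torch.compile(optimizer.step, fullgraph=False)
optimizer.step()
optimizer.zero_grad(set_to_none=False)  # keep the same grad tensors (and addresses)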

I actually submitted this PR to make it less spammy. It will now only show up if you set TORCH_LOGS="perf_hints".

That’s great thank you!

However, I’d also love a solution to the issue I’ve pointed out: it seems important that users have a way to get the best performance possible in an ergonomic/sane fashion.