How to debug custom autograd.Function?

Apparently the `backward()` of a custom `autograd.Function` is not called directly by a Python method (see also this blog entry regarding coverage tests), which is why breakpoints set within these `backward()` methods never get triggered. Is there another way to use a Python debugger for these backward passes?


What I usually do when I need this is to trigger the debugger directly from inside the backward function with `import pdb; pdb.set_trace()`:

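```python
import torch
from torch import autograd

class F(autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        return foo.clone()

    @staticmethod
    def backward(ctx, bar):
        # Open pdb here; this line runs when autograd calls backward()
        import pdb; pdb.set_trace()
        return bar.clone()

a = torch.tensor([3.2], requires_grad=True)
b = F.apply(a)
g = autograd.grad(b, a)[0]  # calls F.backward(), which opens the debugger
```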
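If stopping on every backward pass is too noisy, one option is to guard the call so the debugger only opens when something looks wrong. This is a minimal sketch, assuming "wrong" means the incoming gradient contains NaN or inf. `Scale` is a hypothetical function (y = 2 * x) used only for illustration.

```python
import pdb

import torch
from torch import autograd


class Scale(autograd.Function):
    """Hypothetical example: y = 2 * x, with a guarded debugger call in backward."""

    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        # Only stop when the incoming gradient contains NaN or +/-inf,
        # so ordinary backward passes run without interruption.
        if not torch.isfinite(grad_out).all():
            pdb.set_trace()
        return grad_out * 2


x = torch.tensor([3.2], requires_grad=True)
y = Scale.apply(x)
(g,) = autograd.grad(y, x)
print(g)  # tensor([2.])
```

Here the gradient is finite, so the snippet runs straight through; feed it a non-finite gradient and it drops into pdb inside `backward()`.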


Thank you very much, I forgot about that solution, it is simple and effective!

By the way, @albanD, in case you haven't seen it yet: you can shorten that to `breakpoint()`, which has been a built-in since Python 3.7. It saves a little bit of typing :)
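A minimal sketch of the same function using `breakpoint()`. In normal use you would just run it and land in pdb. Because `breakpoint()` calls `sys.breakpointhook()` (which starts pdb by default), the demo part below temporarily swaps that hook for a recorder so the snippet can run non-interactively, e.g. in a test; the `hits` list is only an illustration. Setting the environment variable `PYTHONBREAKPOINT=0` likewise turns every `breakpoint()` call into a no-op without editing the code.

```python
import sys

import torch
from torch import autograd


class F(autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        return foo.clone()

    @staticmethod
    def backward(ctx, bar):
        # Built-in since Python 3.7; by default this starts pdb here.
        breakpoint()
        return bar.clone()


# Demo only: replace the hook that breakpoint() calls, to confirm that
# backward() reaches it without opening an interactive pdb session.
hits = []
original_hook = sys.breakpointhook
sys.breakpointhook = lambda *args, **kwargs: hits.append("backward")
try:
    a = torch.tensor([3.2], requires_grad=True)
    b = F.apply(a)
    g = autograd.grad(b, a)[0]
finally:
    sys.breakpointhook = original_hook  # restore the default (pdb) hook

print(hits)  # ['backward']
```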

I did not know that :smiley: Thanks for the tip!