Apparently the backward() in a custom autograd.Function does not get called by a Python method directly (see also this blog entry regarding coverage tests), which is why breakpoints set inside these backward() methods never get triggered. Is there another way to use a Python debugger for these backward passes?
Hi,
What I usually do when I need this is to trigger the debugger directly from inside the backward function with import pdb; pdb.set_trace():
import torch
from torch import autograd

class F(autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        return foo.clone()

    @staticmethod
    def backward(ctx, bar):
        import pdb; pdb.set_trace()
        return bar.clone()

a = torch.tensor([3.2], requires_grad=True)
b = F.apply(a)
g = autograd.grad(b, a)[0]
Thank you very much, I had forgotten about that solution; it is simple and effective!
By the way: in case you haven't seen it yet, @albanD, you can shorten that to breakpoint(), which has been a built-in since Python 3.7. It saves a little bit of typing :)
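For reference, a minimal sketch of the same example using the built-in (assuming Python 3.7+; the class and tensor values are just the ones from the snippet above):

import torch
from torch import autograd

class F(autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        return foo.clone()

    @staticmethod
    def backward(ctx, bar):
        # drops into pdb (or whatever debugger PYTHONBREAKPOINT points to) when the backward pass runs
        breakpoint()
        return bar.clone()

a = torch.tensor([3.2], requires_grad=True)
b = F.apply(a)
g = autograd.grad(b, a)[0]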
I did not know that. Thanks for the tip!