backward() in a custom autograd.Function does not get called directly from Python (see also this blog entry regarding coverage tests), which is why breakpoints set within these backward() methods never get triggered. Is there another way to use a Python debugger for these backward passes?
What I usually do when I need this is to start the debugger directly from inside the backward function with import pdb; pdb.set_trace():
import torch
from torch import autograd


class F(autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        return foo.clone()

    @staticmethod
    def backward(ctx, bar):
        # Drops into pdb as soon as the autograd engine calls backward()
        import pdb; pdb.set_trace()
        return bar.clone()


a = torch.tensor([3.2], requires_grad=True)
b = F.apply(a)
g = autograd.grad(b, a)
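Running this stops at the (Pdb) prompt inside backward(), where you can inspect bar, ctx, and the surrounding frames, and then continue with c. On Python 3.7+ the built-in breakpoint() call should work the same way in place of the explicit import pdb; pdb.set_trace().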
Thank you very much, I had forgotten about that solution. It is simple and effective!
I did not know that. Thanks for the tip!