I hope I can explain my scenario well so I can get the correct answer.
I have a custom loss function, say:
loss = function(op1, op2, op3, …, opN), where each op is an operation.
After a couple of epochs I get NaN in the gradient, but I know which operation causes this NaN, say op3 (the SVDBackward operation in my case), and I know how to fix it mathematically.
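For context, this is roughly how I pinpointed the failing op (a minimal sketch; the tensor names and `torch.linalg.svd` standing in for op3 are illustrative, not my actual code):

```python
import torch

# Anomaly detection makes the backward pass raise as soon as a NaN
# gradient appears, and the error message names the responsible
# backward function (SVDBackward in my case).
with torch.autograd.detect_anomaly():
    A = torch.randn(5, 5, requires_grad=True)
    U, S, Vh = torch.linalg.svd(A)  # illustrative stand-in for op3
    loss = S.sum()                  # placeholder for the custom loss
    loss.backward()                 # raises if a NaN gradient shows up
```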
1- If you're wondering why I'm not simply using an if-else statement since I know the problem and how to fix it: the problem happens in the backward pass and can't be detected from the forward pass alone.
2- torch.nn.utils.clip_grad_norm_ will not work in my case.
My question is: is there a way to modify op3 after the backward pass, i.e. after loss.backward()?
Something like this:
loss.backward(), plus some magic code that modifies op3 and is executed once in the current forward pass.
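To make the idea concrete, here is a minimal sketch of what I'm imagining, using a tensor hook. The names (`A`, `U`, `S`, `Vh`) and `torch.linalg.svd` standing in for op3 are assumptions, and `torch.nan_to_num` is just a placeholder for my actual mathematical fix:

```python
import torch

A = torch.randn(5, 5, requires_grad=True)

# SVDBackward writes its result into A's gradient, so a hook on A
# fires right after SVDBackward runs during loss.backward() and can
# sanitize the gradient before it propagates or accumulates further.
A.register_hook(lambda grad: torch.nan_to_num(grad, nan=0.0))

U, S, Vh = torch.linalg.svd(A)     # illustrative stand-in for op3
loss = S.sum()                     # placeholder for the custom loss
loss.backward()
print(torch.isnan(A.grad).any())   # tensor(False): NaNs were replaced
```

Note that the hook fires during loss.backward() itself rather than after it returns, which may actually be what I need: by the time backward() finishes, the NaNs would already have been accumulated into earlier .grad buffers. The only alternative I'm aware of is wrapping op3 in a custom torch.autograd.Function with an overridden backward().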