Modifying a parameter in the forward pass after getting NaN in the backward pass

Hi,

I hope I can explain my scenario well so I can get the correct answer.
I have a custom loss function, say:

loss = function(op1, op2, op3, …, opN), where each op is an operation.

After a couple of epochs, I get NaN in the gradients. I know which operation causes the NaN, say op3 (the SVDBackward operation in my case), and I know how to fix it mathematically.

As a side note:
1- If you’re wondering why I’m not simply using an if-else statement, since I know the problem and how to solve it: the answer is that the problem happens in the backward pass and can’t be noticed from the forward pass alone.

2- torch.nn.utils.clip_grad_norm_ will not work in my case.

My question is: is there a way to modify op3 after the backward pass, i.e., after loss.backward()?

Something like this:

loss.backward()
# magic code that modifies op3, executed once for the current forward pass
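
For concreteness, here is a minimal sketch of what I have in mind, assuming a hypothetical compute_loss with placeholder branches for op3 (the real op3 is the SVD-based step): inspect the gradients right after loss.backward() and flip a flag that the next forward pass checks.

```python
import torch

use_safe_op3 = False  # flag consulted by the forward pass

def compute_loss(x):
    # op1, op2, ... followed by either the original or the mathematically
    # fixed op3, depending on the flag (both branches here are placeholders)
    if use_safe_op3:
        return (x ** 2).sum()               # stands in for the fixed op3
    return torch.linalg.svd(x).S.sum()      # stands in for the original SVD-based op3

x = torch.randn(4, 4, requires_grad=True)
loss = compute_loss(x)
loss.backward()

# After the backward pass, check the gradients for NaN/Inf and flip the flag
# so that the *next* forward pass uses the corrected op3.
if not torch.isfinite(x.grad).all():
    use_safe_op3 = True
    x.grad = None  # discard the faulty gradients before the next step
```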

@ptrblck could you advise on this?

I don’t know how the function to modify the forward pass should work, but you could register a hook on the parameters that create the NaNs, check for invalid values during the backward pass, and set a flag that can later be checked to modify your model.
After the faulty backward pass, you should zero out the gradients, check the flag, and modify your model.

Would this workflow generally work?
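
A rough sketch of that workflow, assuming a stand-in nn.Linear model, a placeholder loss, and a dict-based flag (the actual fix for op3 would go where the comment indicates):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 10)          # stand-in for the real model
nan_detected = {"flag": False}     # flag set from inside the backward pass

def grad_hook(grad):
    # runs during the backward pass and sees this parameter's gradient
    if not torch.isfinite(grad).all():
        nan_detected["flag"] = True
    return grad

# register the hook on every parameter that might receive invalid gradients
for param in model.parameters():
    param.register_hook(grad_hook)

x = torch.randn(2, 10)
loss = model(x).sum()              # placeholder loss
loss.backward()

if nan_detected["flag"]:
    # faulty backward pass: zero out the gradients, reset the flag, and
    # modify the model (e.g., switch op3 to its numerically stable form)
    model.zero_grad()
    nan_detected["flag"] = False
    # ... apply the fix to op3 here ...
# otherwise continue with the usual optimizer.step()
```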