@swethmandava, you could register a backward hook like:
```python
def backward_hook(module, grad_input, grad_output):
    # replace every NaN/Inf entry in the incoming gradients with zero
    return tuple(
        None if g is None else torch.where(torch.isfinite(g), g, torch.zeros_like(g))
        for g in grad_input
    )

model.register_backward_hook(backward_hook)
```
Then it will behave similarly to warpctc_pytorch. Note that this can distort the gradient direction, as @tom mentioned.
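For what it's worth, here is a minimal self-contained sketch of the idea; the `nn.Linear` model and the injected NaN gradient are just placeholders for demonstration, not part of the original setup:

```python
import torch
import torch.nn as nn

def backward_hook(module, grad_input, grad_output):
    # replace every NaN/Inf entry in the incoming gradients with zero
    return tuple(
        None if g is None else torch.where(torch.isfinite(g), g, torch.zeros_like(g))
        for g in grad_input
    )

model = nn.Linear(4, 2)  # placeholder model
model.register_backward_hook(backward_hook)

x = torch.randn(3, 4, requires_grad=True)
out = model(x)
# inject NaNs into the upstream gradient to simulate a bad loss step
out.backward(torch.full_like(out, float("nan")))
print(torch.isfinite(x.grad).all())  # tensor(True): the NaNs were zeroed by the hook
```

On recent PyTorch versions, `register_full_backward_hook` is the documented replacement for the legacy `register_backward_hook` and takes the same `(module, grad_input, grad_output)` signature.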