torch.nn.utils.clip_grad_norm_ slow performance

Good morning,

I implemented the network from the focal loss paper.
Numerous implementations use the function

torch.nn.utils.clip_grad_norm_

to avoid exploding gradients that produce NaNs.
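For reference, this is roughly how the call sits in my training loop (the model, optimizer and max_norm value below are placeholders, not my actual setup):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)                                   # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

inputs = torch.randn(4, 10)
targets = torch.randn(4, 2)

optimizer.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()

# Clip the global gradient norm before the optimizer step;
# the function returns the total norm of the unclipped gradients.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

optimizer.step()
```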
However, after profiling my implementation and the different parts of the algorithm, it seems that this function takes as much time as the full forward pass through the network for a Full HD image.
Looking at the code, the tensors are apparently moved to the CPU and processed there.
Any feedback on this function, or ideas to overcome this performance problem?
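As a possible workaround I am considering the sketch below: a norm-based clip that keeps the total norm as a tensor on the device instead of converting per-parameter norms to Python floats (which forces a GPU-to-CPU synchronization for every tensor). This is my own sketch, not the library implementation, and the epsilon and clamping details are assumptions:

```python
import torch

def clip_grad_norm_on_device_(parameters, max_norm, norm_type=2.0):
    """Sketch of a global-norm gradient clip that avoids per-parameter
    .item() calls, so no GPU->CPU sync happens until the caller asks
    for the returned norm."""
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return torch.tensor(0.0)
    # Reduce the per-parameter norms on the device in one go.
    total_norm = torch.norm(
        torch.stack([torch.norm(g.detach(), norm_type) for g in grads]),
        norm_type,
    )
    # Scale factor, clamped to 1 so gradients are only ever shrunk.
    clip_coef = (max_norm / (total_norm + 1e-6)).clamp(max=1.0)
    for g in grads:
        g.detach().mul_(clip_coef)
    return total_norm
```

I have not verified that this matches the built-in function numerically in every case, so any comments on it are welcome as well.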

Yours
Justin


Did you figure this out?
