How to solve the spikes of loss at some epoch?

During training, I see spikes in the loss at some epochs, as shown below:
[image: training loss curve with spikes]

My idea is to compare the current loss with the previous one: if the difference between them is too big, I skip the gradient update. Does PyTorch have a function that implements this? Currently I use:


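# prev_loss (a float) and threshold are assumed to be initialized before the loop
current_loss = cross_entropy(input, target)
# Backpropagate only if the loss did not jump by more than threshold
if current_loss.item() - prev_loss <= threshold:
    current_loss.backward()
    prev_loss = current_loss.item()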

Exploding and Vanishing Gradients

If you suffer from exploding gradients, I recommend using clip_grad_norm_.
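For reference, here is a minimal sketch of where clip_grad_norm_ would sit in a training step. The toy model, the random data, and max_norm = 1.0 are assumptions made only for illustration and are not taken from the original post. The call goes between backward() and optimizer.step().

```python
import torch
from torch import nn
from torch.nn.utils import clip_grad_norm_

torch.manual_seed(0)

# Hypothetical toy model and data, used only for illustration.
model = nn.Linear(10, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
inputs = torch.randn(8, 10) * 100.0  # large inputs to provoke large gradients
targets = torch.randint(0, 3, (8,))

max_norm = 1.0  # assumed threshold; tune it for your task

optimizer.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()

# Rescales all gradients in place so that their combined L2 norm is at most
# max_norm. The return value is the total norm measured before clipping.
norm_before = clip_grad_norm_(model.parameters(), max_norm)

# Recompute the combined L2 norm to confirm the gradients were rescaled.
norm_after = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))

optimizer.step()
print(f"grad norm before: {float(norm_before):.3f}, after: {float(norm_after):.3f}")
```

Note that the function takes a single upper bound on the norm (max_norm), not a lower and upper bound on individual gradient values.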

@Tony-Y: Thanks, but how do I know the lower and upper bounds of the gradient to pass to the function?

One good heuristic for setting this threshold is to look at statistics on the average norm over a sufficiently large number of updates. In our experiments we have noticed that for a given task and model size, training is not very sensitive to this hyper-parameter and the algorithm behaves well even for rather small thresholds.
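The heuristic above could be sketched roughly as follows: run a number of ordinary updates, record the total gradient norm at each step, and look at the statistics. The toy model, the random batches, the 200 steps, and using the mean norm as a starting value are all assumptions for this sketch. Passing max_norm=float("inf") makes clip_grad_norm_ return the norm without changing the gradients.

```python
import statistics

import torch
from torch import nn
from torch.nn.utils import clip_grad_norm_

torch.manual_seed(0)

# Hypothetical toy model; replace with your own model and data loader.
model = nn.Linear(10, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

grad_norms = []
for step in range(200):
    inputs = torch.randn(16, 10)          # stand-in for a real batch
    targets = torch.randint(0, 3, (16,))
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    # With max_norm=inf nothing is clipped; the call only reports the total norm.
    total_norm = clip_grad_norm_(model.parameters(), max_norm=float("inf"))
    grad_norms.append(float(total_norm))
    optimizer.step()

mean_norm = statistics.mean(grad_norms)
std_norm = statistics.pstdev(grad_norms)
largest_norm = max(grad_norms)

# Assumption for this sketch: start from the average norm and adjust from there.
suggested_max_norm = mean_norm
print(f"mean={mean_norm:.4f} std={std_norm:.4f} max={largest_norm:.4f}")
```

Steps whose norm is far above the mean are the likely sources of loss spikes, so a threshold near the typical norm leaves normal updates mostly untouched while damping the outliers.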