During training, I see sudden spikes in the loss at some epochs.
I want to compare each loss with the previous one: if the current loss exceeds the previous loss by more than a threshold, I skip the weight update for that step. Is there a PyTorch function that implements this check? Currently, I use:
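from torch.nn.functional import cross_entropy
current_loss = cross_entropy(input, target)
# Update only when the loss did not jump by more than `threshold` (`optimizer` is the torch.optim optimizer)
if current_loss.item() - prev_loss <= threshold:
    optimizer.zero_grad()
    current_loss.backward()
    optimizer.step()
prev_loss = current_loss.item()  # store a float so the previous graph is freed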
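A minimal sketch of the same check wrapped in a reusable helper, assuming the loss is converted to a Python float with `.item()` before the comparison. `LossSpikeGuard`, `should_update`, and `threshold=1.0` are hypothetical names and values, not PyTorch APIs; `model`, `optimizer`, and `loader` in the commented loop are assumed to exist. Skipping NaN/inf losses is an addition beyond the original check.

```python
import math
from typing import Optional


class LossSpikeGuard:
    """Hypothetical helper (not a PyTorch API) that flags loss spikes.

    A step counts as a spike when the loss rises by more than `threshold`
    over the previous finite loss value; NaN/inf losses are always skipped.
    """

    def __init__(self, threshold: float) -> None:
        self.threshold = threshold
        self.prev_loss: Optional[float] = None  # nothing to compare against yet
        self.skipped = 0  # number of updates that were skipped

    def should_update(self, loss_value: float) -> bool:
        if not math.isfinite(loss_value):
            # Never step on NaN/inf, and do not use it as the new reference.
            self.skipped += 1
            return False
        ok = self.prev_loss is None or loss_value - self.prev_loss <= self.threshold
        if not ok:
            self.skipped += 1
        # As in the original snippet, the latest loss becomes the new reference
        # even when the update is skipped.
        self.prev_loss = loss_value
        return ok


# Intended use inside a PyTorch training loop (model, optimizer, loader assumed):
#
#   guard = LossSpikeGuard(threshold=1.0)
#   for inputs, targets in loader:
#       loss = torch.nn.functional.cross_entropy(model(inputs), targets)
#       if guard.should_update(loss.item()):
#           optimizer.zero_grad()
#           loss.backward()
#           optimizer.step()

if __name__ == "__main__":
    guard = LossSpikeGuard(threshold=1.0)
    losses = [2.3, 2.1, 5.0, 2.0, float("nan"), 1.9]
    print([guard.should_update(x) for x in losses])  # [True, True, False, True, False, True]
    print(guard.skipped)  # 2
```

Because the spike value still becomes the reference, the step right after a spike is compared against the spike and usually goes through; if you would rather keep comparing against the last accepted loss, move the `self.prev_loss = loss_value` assignment inside an `if ok:` branch.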