Ignoring losses that are too high

Due to certain data, I want my network to learn only on samples whose loss is lower than 1. For that I use the following code:

def train(self, img, label):
    x = self.net(img)
    loss_val = self.loss(x, label)
    if loss_val < 1:
        loss_val.backward()
        self.optimiser.step()

Is this correct? Because I have the feeling that the gradients are not resetting, so as soon as loss_val is under 1, all the accumulated gradient updates are applied, even the ones from iterations where the losses were higher than 1.

How can I otherwise reset all the gradients from the optimiser?

The code looks alright. What do you mean by “the gradients are not resetting”?
Are you seeing valid gradients at the end of the iteration even though the loss was >=1?

I don’t really know how to check that.

My worry is that the following is happening. I run an iteration where the loss is over 1 and I don't call backward() on the loss, so all tensors still hold their gradient graphs from that iteration. Then, on the second iteration, the graphs are extended with the new iteration's information (while still holding the information from the first one), and if the second iteration's loss is lower than 1, the gradients are computed, but they also include the information from the first iteration, whose loss was actually higher than 1.

Each forward pass will create a new computation graph.
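To make that concrete, here is a toy sketch (a stand-in linear layer, not your actual model): if you never call backward() on a loss, that iteration's graph is simply discarded when the loss tensor goes out of scope, so nothing from it leaks into a later update.

```python
import torch

# Toy setup: a single linear layer standing in for self.net (assumption).
net = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(net.parameters(), lr=0.1)

# Iteration 1: pretend the loss was too high -> no backward().
x1 = torch.randn(2, 4)
loss1 = net(x1).pow(2).mean()
# loss1.backward() is never called; its graph is discarded.

# Iteration 2: loss is acceptable -> backward() and step().
opt.zero_grad()
x2 = torch.randn(2, 4)
loss2 = net(x2).pow(2).mean()
loss2.backward()  # gradients come only from this iteration's graph
opt.step()
```

The gradients after the second iteration depend only on x2; the first iteration's graph never contributed to them.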
To double check your use case, you can print the gradients via:

for name, param in net.named_parameters():
    print(name, param.grad)

You can print them at different places to see if the gradients are valid or have been zeroed out.
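As for explicitly resetting the gradients: the usual pattern is to call zero_grad() at the top of every iteration, before the forward pass, so a skipped iteration can never leave stale state behind. A sketch with placeholder names (train_step, loss_fn, and the toy model are illustrative, not the original code):

```python
import torch

def train_step(net, optimiser, loss_fn, img, label):
    # Reset gradients unconditionally, before the forward pass,
    # so nothing stale survives from an earlier iteration.
    optimiser.zero_grad()
    out = net(img)
    loss_val = loss_fn(out, label)
    if loss_val.item() < 1:
        loss_val.backward()
        optimiser.step()
    return loss_val.item()

# Toy usage (all names are placeholders):
net = torch.nn.Linear(4, 2)
optimiser = torch.optim.SGD(net.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()
img = torch.randn(8, 4)
label = torch.randint(0, 2, (8,))
val = train_step(net, optimiser, loss_fn, img, label)
```

With this placement it does not matter whether the previous iteration called backward() or not; the update for the current iteration always starts from zeroed gradients.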