Because of the nature of my data, I want my network to learn only from losses that are lower than 1. For that I use the following code:
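```python
def train(self, img, label):
    self.optimizer.zero_grad()  # self.optimizer: assumed name for the optimizer
    x = self.net(img)
    loss_val = self.loss(x, label)
    if loss_val < 1:  # only learn from losses below 1
        loss_val.backward()
        self.optimizer.step()
```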
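A self-contained sketch of the same pattern, assuming a hypothetical toy setup (a one-feature `nn.Linear` model, `nn.MSELoss` and plain SGD) in place of `self.net`, `self.loss` and the optimizer from the post. `train_step` and `threshold` are illustrative names, not part of any API:

```python
import torch
import torch.nn as nn

# Hypothetical toy setup: a 1-feature linear model, MSE loss and plain SGD.
torch.manual_seed(0)
net = nn.Linear(1, 1)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)


def train_step(img, label, threshold=1.0):
    """Run one step; update the weights only if the loss is below `threshold`.

    Returns the loss value and whether an update was applied.
    """
    optimizer.zero_grad()           # clear gradients left over from the last update
    out = net(img)                  # a fresh computation graph is built here
    loss_val = loss_fn(out, label)
    if loss_val.item() < threshold:
        loss_val.backward()         # gradients are written to .grad only here
        optimizer.step()
        return loss_val.item(), True
    # No backward(): the graph is freed once loss_val goes out of scope.
    return loss_val.item(), False
```

Calling `train_step` with a far-off target (loss well above 1) leaves the weights untouched, while a target close to the current output (loss below 1) triggers an update.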
Is this correct? I have the feeling that the gradients are not being reset: as soon as a loss_val is below 1, all the gradient updates are applied, including those from iterations where the loss was above 1.
If so, how can I reset all the gradients held by the optimiser?
The code looks alright. What do you mean by “the gradients are not resetting”?
Are you seeing valid gradients at the end of the iteration even though the loss was >=1?
I don’t really know how to check that.
My worry is that the following happens. In an iteration where the loss is above 1, I don't call backward on the loss, so the tensors still hold their computation graph from that iteration. In the next iteration, that graph is extended with the information of the new iteration (while still holding the information of the first one), and if the second loss is below 1, the gradients are computed, but they also include the contribution of the first iteration, whose loss was actually above 1.
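Each forward pass creates a new computation graph. Gradients are only written into the parameters' `.grad` attributes when `backward()` is called, so if you skip `backward()` for a loss >= 1, that iteration's graph is simply freed once nothing references it and does not affect later updates.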
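A minimal check of this, assuming a hypothetical one-parameter loss `(w * x) ** 2`, whose gradient with respect to `w` is `2 * w * x**2`. The first iteration's loss is above 1 and is skipped; the second is below 1 and is backpropagated:

```python
import torch

# Hypothetical single-parameter example: loss = (w * x) ** 2, so dloss/dw = 2 * w * x**2.
w = torch.tensor(1.0, requires_grad=True)

# Iteration 1: loss = 9.0 (>= 1), so backward() is skipped.
loss1 = (w * 3.0) ** 2
if loss1.item() < 1:
    loss1.backward()
grad_after_skip = w.grad            # still None: nothing was accumulated

# Iteration 2: loss = 0.25 (< 1), so backward() runs on this graph only.
loss2 = (w * 0.5) ** 2
if loss2.item() < 1:
    loss2.backward()
grad_after_update = w.grad.item()   # 2 * 1.0 * 0.5**2 = 0.5
```

If the skipped graph leaked into the update, `w.grad` would be 18.5 (18 from the first loss plus 0.5 from the second) rather than 0.5.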
To double check your use case, you can print the gradients via:
```python
for name, param in net.named_parameters():
    print(name, param.grad)  # None if no gradient has been computed yet
```
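You can print them at different places (for example, right before and after the `if loss_val < 1:` block) to see whether the gradients are valid or have been zeroed out. Note that in recent PyTorch versions `optimizer.zero_grad()` sets the gradients to `None` by default instead of filling them with zeros.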
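A sketch of a helper that records gradient norms instead of printing full tensors, so several checkpoints can be compared side by side. `grad_report` and the toy model are hypothetical; `zero_grad(set_to_none=True)` is passed explicitly so the result does not depend on the installed PyTorch version's default:

```python
import torch
import torch.nn as nn


def grad_report(net):
    """Map each parameter name to its gradient norm, or None if it has no gradient."""
    return {
        name: (None if param.grad is None else param.grad.norm().item())
        for name, param in net.named_parameters()
    }


# Hypothetical toy model and optimizer, only to show where grad_report can be called.
torch.manual_seed(0)
net = nn.Linear(2, 1)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

before = grad_report(net)                   # no backward yet
loss = net(torch.ones(1, 2)).pow(2).sum()
loss.backward()
after_backward = grad_report(net)           # gradients populated
optimizer.zero_grad(set_to_none=True)
after_zero = grad_report(net)               # gradients reset to None
```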