I want to do a backward pass only when the output is logit 4. This is the relevant part of my code:
if output[:, 4] == True:
The if condition is not working; the rest is fine. Any ideas on how to write this condition? Thank you in advance.
I don’t quite understand. Do you want to compare the actual logit value against 4, or slice the output?
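To make the two readings concrete, here is a minimal sketch. Random logits stand in for the real model output, and the batch size of 3 and the 20 classes are assumptions. Slicing with `output[:, 4]` returns the class-4 logit for every sample. Comparing it against a number returns one bool per sample. Neither gives a single truth value for a batch larger than 1, so an `if` on the result fails. The original `== True` check has the same problem.

```python
import torch

torch.manual_seed(0)
output = torch.randn(3, 20)  # hypothetical logits: batch of 3 samples, 20 classes

# Slicing: output[:, 4] selects the logit of class index 4 for every sample
logits_class4 = output[:, 4]
print(logits_class4.shape)  # torch.Size([3])

# Comparing values: each logit is compared against a number, giving one bool per sample
is_four = output[:, 4] == 4
print(is_four.dtype, is_four.shape)  # torch.bool torch.Size([3])

# Using a multi-element bool tensor in an `if` raises a RuntimeError,
# because it has no single truth value (the same happens with `== True`)
raised = False
try:
    if is_four:
        pass
except RuntimeError:
    raised = True
print(raised)  # True
```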
Hi @ptrblck, I am not sure I understand what slicing the output means. I have up to 20 classes, and the output ranges from 1 to 20. If the output is 4, then I want to do a backward pass.
Do you want to check whether the logit for class 4 has the largest value, so that class 4 would be predicted? If so, use torch.argmax to get the predicted class and then use it as the condition.
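A minimal sketch of the argmax-based condition follows. It is not the original poster's code. The `nn.Linear` model, the `CrossEntropyLoss` criterion, the random inputs and the labels are all hypothetical stand-ins. It also assumes that "class 4" means column 4 of the logits, as in `output[:, 4]`. `torch.argmax` returns 0-based indices, so if the classes are numbered 1 to 20, class 4 would be index 3. Because a batch can contain several samples, the sketch builds a per-sample mask and calls `.any()` to get a single truth value for the `if`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy setup: 8 input features, 20 classes, batch of 16
model = nn.Linear(8, 20)
criterion = nn.CrossEntropyLoss()
x = torch.randn(16, 8)
target = torch.randint(0, 20, (16,))  # hypothetical labels

model.zero_grad()
output = model(x)                   # logits, shape (16, 20)
pred = torch.argmax(output, dim=1)  # predicted class index per sample, shape (16,)

# Assumption: "class 4" is column 4 of the logits, as in output[:, 4].
# argmax returns 0-based indices, so with classes numbered 1..20,
# class 4 would be index 3 instead.
mask = pred == 4                    # bool tensor, one entry per sample

did_backward = False
if mask.any():                      # .any() reduces the mask to a single truth value
    # loss computed only on the samples predicted as class 4
    loss = criterion(output[mask], target[mask])
    loss.backward()
    did_backward = True
print(int(mask.sum()), did_backward)
```

With a batch size of 1, `if pred.item() == 4:` is an equivalent, simpler condition.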