How to write an if condition on a certain output

I want to do a backward pass only when the output is logit 4. This is the relevant part of my code:

if output[:, 4] == True:
    output[:, 4].backward()

The if condition is not working, but the rest is fine. Any ideas on how to write this condition? Thank you in advance.

I don’t quite understand

Do you want to compare the actual logit value against 4 or slice the output?
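To illustrate the difference between the two readings, here is a small sketch with random logits. It assumes `output` has shape `[batch_size, num_classes]` (a batch of 3 and 20 classes are made up for the example). `output[:, 4]` is a slice: it picks the logit at column index 4 for every sample, so comparing it to `True` gives one boolean per sample, and Python's `if` cannot turn a multi-element tensor into a single True/False.

```python
import torch

torch.manual_seed(0)
# Hypothetical logits: batch of 3 samples, 20 classes
output = torch.randn(3, 20)

# Slicing: the logit at column index 4 for every sample in the batch
logit_4 = output[:, 4]
print(logit_4.shape)  # torch.Size([3])

# The original condition compares that slice element-wise against True,
# which produces a bool tensor with one entry per sample
mask = output[:, 4] == True
print(mask.shape)  # torch.Size([3])

raised = False
try:
    if mask:  # ambiguous for a tensor with more than one element
        pass
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```

With a batch size larger than 1 the `if` raises a `RuntimeError` ("Boolean value of Tensor with more than one value is ambiguous"); with a batch size of 1 it silently tests whether the logit equals 1.0, which is not what the question is after either.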

Hi @ptrblck, I am not sure what slicing the output means. I have up to 20 classes, and the output ranges from 1 to 20.

If the output is 4, then I want to do a backward pass.

Do you want to check whether the logit for class 4 has the largest value, so that class 4 would be predicted?
If so, use torch.argmax to get the predicted class and then use it in the condition.
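A minimal sketch of the argmax-based condition, assuming `output` holds raw logits of shape `[batch_size, 20]` and that "class 4" means column index 4. argmax returns 0-based indices, so if your labels run from 1 to 20, class 4 would be index 3. The names `TARGET_INDEX`, `backward_if_predicted`, and `backward_on_matching` are hypothetical. The first function handles a single sample; the second handles a batch by masking the samples predicted as the target class.

```python
import torch

TARGET_INDEX = 4  # hypothetical: 0-based column index of "class 4"


def backward_if_predicted(output: torch.Tensor, target: int = TARGET_INDEX) -> bool:
    """Single sample: backprop the target logit only if it is the argmax.

    Assumes `output` has shape [1, num_classes].
    """
    pred = torch.argmax(output, dim=1)  # shape [1], dtype long
    if pred.item() == target:           # .item() gives a plain Python int
        output[0, target].backward()
        return True
    return False


def backward_on_matching(output: torch.Tensor, target: int = TARGET_INDEX) -> int:
    """Batch: backprop the target logit of every sample predicted as `target`.

    Returns the number of matching samples.
    """
    preds = torch.argmax(output, dim=1)  # shape [batch_size]
    mask = preds == target               # one bool per sample
    if mask.any():                       # reduces to a single bool, safe in `if`
        output[mask, target].sum().backward()
    return int(mask.sum())


# Usage with hand-made logits
single = torch.zeros(1, 20)
single[0, 4] = 5.0
single.requires_grad_()
print(backward_if_predicted(single))  # True

batch = torch.zeros(3, 20)
batch[0, 4] = 5.0  # predicted class 4
batch[1, 7] = 5.0  # predicted class 7
batch[2, 4] = 5.0  # predicted class 4
batch.requires_grad_()
print(backward_on_matching(batch))  # 2
```

Summing the selected logits before calling `backward()` is one possible choice, since `backward()` needs a scalar; you could also feed the masked samples into your loss function instead.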
