Gradient on parts of feature maps


I need some advice on how best to approach this problem. I want to compare feature maps from a network with K classes. For simplicity, let's assume the feature maps (with batch size 1) are each a vector/tensor of size (1, 1, K).

I only want to compare and update based on single values, whose position along K changes depending on the input. If I slice the feature map and set requires_grad=False, it disables the gradient for the entire tensor. Basically, I am looking for something like an “ignore index”, where I can train only on one entry, e.g. position 4 of the (1, 1, K) map, without updating on the remaining values.
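One common way to get this “ignore index” behaviour, sketched here under assumptions rather than as a definitive answer: keep the full feature map in the graph and either multiply the per-element loss by a mask or `gather` only the selected entries. Positions outside the mask then contribute nothing to the loss and receive a zero gradient automatically, so there is no need to touch `requires_grad` at all. The names `masked_mse` and `gathered_mse`, the (N, 1, K) layout, the per-sample index tensor `idx`, and the use of MSE as the comparison are all hypothetical placeholders.

```python
import torch
import torch.nn.functional as F


def masked_mse(pred: torch.Tensor, target: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """MSE over one position per sample; pred/target are (N, 1, K), idx is (N,).

    Assumption: idx holds one long index per sample along the last dimension.
    """
    # Float mask with a 1.0 at each sample's selected position, 0.0 elsewhere.
    mask = torch.zeros_like(pred).scatter_(-1, idx.view(-1, 1, 1), 1.0)
    per_elem = F.mse_loss(pred, target, reduction="none")
    # Masked-out entries are multiplied by 0, so their gradient is 0 as well.
    return (per_elem * mask).sum() / mask.sum()


def gathered_mse(pred: torch.Tensor, target: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Same loss, but picks the selected entries out with gather instead of a mask."""
    index = idx.view(-1, 1, 1)           # (N, 1, 1) to match pred's number of dims
    p = pred.gather(-1, index)           # (N, 1, 1): only the chosen entries
    t = target.gather(-1, index)
    return F.mse_loss(p, t)


if __name__ == "__main__":
    torch.manual_seed(0)
    K = 10
    pred = torch.randn(1, 1, K, requires_grad=True)  # stand-in feature map
    target = torch.randn(1, 1, K)                    # map to compare against
    idx = torch.tensor([4])                          # would come from the input

    loss = masked_mse(pred, target, idx)
    loss.backward()
    print(pred.grad)  # non-zero only at position 4
```

The mask version is convenient when several positions per sample should count; the `gather` version avoids building a K-sized mask when only one position matters. Both produce the same value for a single index per sample.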

I'd appreciate any help and pointers 🙂

Best Regards,