I have 4D logit matrices with the shape (batchsize, classes, height, width), e.g. 9 logit values per pixel if the number of classes is 9, and the assigned class matrices (label matrices) with the shape (batchsize, 1, height, width), i.e. effectively 3D, where each entry is the index of the true class for that pixel. How can the logit values corresponding to the assigned class be accessed easily? Basically, I want to use the label matrices as an index into the logit matrices. The label matrices additionally contain the default value 255, which marks pixels that should be excluded from the gradient calculation. Hence, these pixels have to be filtered out first.
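To make the goal concrete, here is a minimal sketch of the kind of indexing I mean. It assumes NumPy arrays, since no framework is fixed above; the function name `gather_true_class_logits`, the `ignore_index` parameter and the toy data are made up for illustration. It is one possible approach, not necessarily the best one.

```python
import numpy as np

def gather_true_class_logits(logits, labels, ignore_index=255):
    """Pick, for every pixel, the logit of its labelled class.

    logits: (batchsize, classes, height, width)
    labels: (batchsize, 1, height, width), class indices or ignore_index
    Returns a 1D array with the selected logits of all valid pixels
    and the boolean mask of valid pixels (same shape as labels).
    """
    # Pixels carrying the default value are excluded.
    valid = labels != ignore_index
    # Replace the default value by a harmless index (0) so that the
    # lookup below never goes out of range; those entries are dropped later.
    safe_labels = np.where(valid, labels, 0).astype(np.int64)
    # Use the labels as an index along the class axis (axis=1).
    picked = np.take_along_axis(logits, safe_labels, axis=1)
    # Keep only the valid pixels.
    return picked[valid], valid

# Toy data (assumed): 1 image, 3 classes, 2x2 pixels.
logits = np.arange(12, dtype=np.float32).reshape(1, 3, 2, 2)
labels = np.array([[[[0, 2], [255, 1]]]])
selected, valid = gather_true_class_logits(logits, labels)
print(selected)  # the logits 0, 9 and 7; the pixel labelled 255 is skipped
```

If the data lives in PyTorch tensors instead, I assume `torch.gather` along dimension 1 would play the role of `np.take_along_axis`, with the same masking before and after.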
Thank you in advance!