I want to subdivide my batch of images into groups of true positives, true negatives, etc. according to some binary classification criterion, which is not important for the purpose of this post. Let's say I have a tensor of predictions (the result of my forward pass) and a tensor of labels of equal size:
pred = torch.tensor([1, 0, 0, 1])
labels = torch.tensor([1, 1, 0, 0])
So the first entry is an example of true-positive, etc. Then I get the indices of every such instance like so
# transform into byte tensors for logical operations
pred = pred.byte()
labels = labels.byte()
# find indices where pred == 1 and pred == 0
true = pred.nonzero().squeeze()
false = (pred == 0).nonzero().squeeze()
# initialize all four tensors with according size to hold either 1 or 0 depending if the sample is in fact true-positive, etc. mutually exclusive across all four tensors
tp = torch.zeros_like(pred)
tn = torch.zeros_like(pred)
fn = torch.zeros_like(pred)
fp = torch.zeros_like(pred)
# set the entries to 1 at the respective index
tp[true[pred[true] == labels[true]]] = 1
fp[true[pred[true] != labels[true]]] = 1
tn[false[pred[false] == labels[false]]] = 1
fn[false[pred[false] != labels[false]]] = 1
So the result should agree with the pred and labels tensors:
tp = torch.tensor([1, 0, 0, 0])
fp = torch.tensor([0, 0, 0, 1])
tn = torch.tensor([0, 0, 1, 0])
fn = torch.tensor([0, 1, 0, 0])
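For comparison, here is a minimal sketch of how the same four groups could be computed with elementwise boolean operations. It assumes pred and labels are converted to torch.bool rather than byte tensors, so it is an alternative to the code above, not a copy of it:

```python
import torch

pred = torch.tensor([1, 0, 0, 1]).bool()
labels = torch.tensor([1, 1, 0, 0]).bool()

# Each mask is True exactly where the sample belongs to that group.
tp = pred & labels      # predicted 1, label 1
fp = pred & ~labels     # predicted 1, label 0
tn = ~pred & ~labels    # predicted 0, label 0
fn = ~pred & labels     # predicted 0, label 1
```

The four masks are mutually exclusive, so every sample ends up in exactly one group.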
Now I use those four tensors to filter the image batch and color code the images either green, blue, yellow or red by setting the respective color channel to zero
# not touching the H or W dimension, only batch size dimension and channel dimension
images[tp,0] = 0
images[tp,2] = 0
images[tn,0] = 0
images[tn,1] = 0
images[fp,2] = 0
images[fn,1] = 0
images[fn,2] = 0
and this seems to work. But I am wondering why. Are those four tensors treated as logical tensors for the access?
From the indexing rules I expected the entries to be interpreted as integer indices, so that only the first and second images (indices 0 and 1) would be accessed.
So what is the difference from indexing the images with the actual indices?
images[0,0] = 0
images[0,2] = 0
images[2,0] = 0
images[2,1] = 0
images[3,2] = 0
images[1,1] = 0
images[1,2] = 0
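To make the two interpretations concrete, here is a small sketch contrasting a mask with integer indices that hold the same values. The images batch is a hypothetical 4x3x2x2 tensor of ones, and the mask is a torch.bool tensor. As far as I understand, byte (uint8) tensors have been treated as masks in the same way, with newer PyTorch releases preferring torch.bool, but take that as an assumption rather than a statement about every version:

```python
import torch

# Hypothetical batch: N=4 images, C=3 channels, H=W=2
images = torch.ones(4, 3, 2, 2)

mask = torch.tensor([1, 0, 0, 0]).bool()  # boolean mask: selects image 0 only
idx = torch.tensor([1, 0, 0, 0])          # int64 indices: selects images 1, 0, 0, 0

by_mask = images.clone()
by_mask[mask, 0] = 0    # zeroes channel 0 where the mask is True

by_index = images.clone()
by_index[idx, 0] = 0    # zeroes channel 0 of the images whose indices are listed

# Per image: was channel 0 zeroed completely?
changed_by_mask = (by_mask[:, 0] == 0).flatten(1).all(dim=1)
changed_by_index = (by_index[:, 0] == 0).flatten(1).all(dim=1)

# A boolean mask behaves like the integer positions of its True entries.
positions = mask.nonzero().squeeze(1)
```

With the mask only image 0 changes, whereas the integer tensor with the same values touches images 0 and 1. That second case is the behaviour I had expected from the indexing rules.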