Let’s say you have a tensor like this:
mytensor = torch.tensor([[[-11,0,101,0],[0,5,50,0],[-1,-2,0,1.]]])
And you define your mask as the zero entries: mask = mytensor.eq(0.)
Now you want to apply several reductions along dim 1, such as max, min and mean (and potentially custom ones), while ignoring the masked entries.
Doing something like this:
mytensor.mean(1)
tensor([[-4.0000, 1.0000, 50.3333, 0.3333]])
is not going to work, because the zeros are counted too. For example, the first value should be (-11 - 1)/2 = -6.
You can hack the mean like this: mytensor.sum(1).div((~mask).sum(1).float()), but this is not going to work for min and max, or for any custom operation.
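For min and max specifically, one workaround is to fill the ignored entries with the identity element of the reduction (-inf for max, +inf for min) so they can never be selected. A minimal sketch, assuming PyTorch 1.7 or later (for amax/amin); note that a column where every entry is masked comes out as -inf/+inf for max/min, and its mean comes out as 0.

```python
import torch

mytensor = torch.tensor([[[-11, 0, 101, 0], [0, 5, 50, 0], [-1, -2, 0, 1.]]])
mask = mytensor.eq(0.)  # True where the entry should be ignored

# Masked mean: zero out ignored entries (a no-op here, but needed for a general mask)
# and divide by the number of kept entries; clamp(min=1) avoids division by zero
# when a whole column is masked.
count = (~mask).sum(1).clamp(min=1)
masked_mean = mytensor.masked_fill(mask, 0.).sum(1) / count

# Masked max/min: ignored entries become -inf/+inf, so they never win the reduction.
masked_max = mytensor.masked_fill(mask, float('-inf')).amax(1)
masked_min = mytensor.masked_fill(mask, float('inf')).amin(1)
```

With the example tensor this gives a mean of [[-6, 1.5, 75.5, 1]], a max of [[-1, 5, 101, 1]] and a min of [[-11, -2, 50, 1]].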
nn.CrossEntropyLoss has an ignore_index parameter to handle this, but I can’t see any way to do it for a generic operation. Is there a clean, non-hacky way to do it?
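One possible sketch, not an official answer. For reductions that PyTorch ships in a NaN-aware form (torch.nansum, torch.nanmean, torch.nanmedian, torch.nanquantile), you can turn the mask into NaNs; this assumes PyTorch 1.10+ for torch.nanmean and that the data contains no genuine NaNs. For arbitrary custom operations, a fallback is to loop over the slices and apply the function to the kept entries only. masked_reduce below is a hypothetical helper written for this example, not a PyTorch API; it is slow on large tensors, and fn must handle an empty slice if a whole slice can be masked.

```python
import torch

mytensor = torch.tensor([[[-11, 0, 101, 0], [0, 5, 50, 0], [-1, -2, 0, 1.]]])
mask = mytensor.eq(0.)  # True where the entry should be ignored

# Option 1: NaN-aware reductions skip the entries we turned into NaN.
nan_filled = mytensor.masked_fill(mask, float('nan'))
nan_mean = torch.nanmean(nan_filled, dim=1)

# Option 2 (hypothetical helper): apply any fn to the unmasked entries of each slice.
def masked_reduce(x, mask, fn, dim):
    """Apply fn (1-D tensor -> 0-d tensor) to the kept entries of every slice along dim."""
    x_moved = x.movedim(dim, -1)      # put the reduced dimension last
    m_moved = mask.movedim(dim, -1)
    out_shape = x_moved.shape[:-1]
    x_flat = x_moved.reshape(-1, x_moved.shape[-1])
    m_flat = m_moved.reshape(-1, m_moved.shape[-1])
    # Python loop over slices: general, but not vectorized.
    results = [fn(row[~row_mask]) for row, row_mask in zip(x_flat, m_flat)]
    return torch.stack(results).reshape(out_shape)

# Example custom operation: the range (max - min) of the kept values per column.
value_range = masked_reduce(mytensor, mask, lambda v: v.max() - v.min(), dim=1)
```

Here nan_mean is [[-6, 1.5, 75.5, 1]] and value_range is [[10, 7, 51, 0]]. The NaN route stays vectorized but only covers the nan* functions; the loop covers anything at the cost of speed.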