You will run into problems with one of the inner dimensions, because you won't always have the same number of elements.
Here's a slightly over-simplified example:
import torch
t = torch.Tensor(2, 2).random_(0, 2).long()
# e.g. t = [[0, 0],
#           [1, 0]] (long tensor)
m = t == 1
# m = [[0, 0],
#      [1, 0]] (byte tensor this time; a bool tensor in recent PyTorch versions)
et = torch.Tensor(2, 2, 2).zero_()
et.scatter_(1, t.unsqueeze(1), 1)
# et = [[[1, 1],
#        [0, 0]],
#       [[0, 1],
#        [1, 0]]] (float)
m = m.unsqueeze(1).expand(et.size())
# ~m = [[[1, 1],
#        [1, 1]],
#       [[0, 1],
#        [0, 1]]] (byte)
# et[~m] = [1, 1, 0, 0, 1, 0]
As you can see, et still has the information you need, but because the result of et[~m] is flattened, you lose the dimensional information. What you want to get is
et[~m] = [[[1, 0],
           [1, 0]],
          [[1, 0]]]
This is not a regular array: the second dimension has size 2 for the first batch element and 1 for the second, so it cannot be stored as a single tensor.
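To make the irregularity concrete, here is a minimal sketch that counts how many elements survive the mask in each row of the toy example. It hard-codes t to the values shown in the comments above (an assumption, since the original draws them at random), and kept_per_row is just an illustrative name.

```python
import torch

# Assumption: fix t to the example values instead of sampling them.
t = torch.tensor([[0, 0],
                  [1, 0]])
m = t == 1                       # True where the element should be dropped

# Number of kept elements per batch element (row).
kept_per_row = (~m).sum(dim=1)
print(kept_per_row)              # tensor([2, 1])
```

The counts differ (2 and 1), which is why the masked result cannot keep a batch dimension.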
The right way to do it is to put the one-hot encoding on the first dimension:
# target tensor
t = torch.Tensor(4, 5, 5).random_(0, 10).long()
# binary mask
m = t == 3
# one hot encoded tensor
et = torch.Tensor(10, 4, 5, 5).zero_()
# one hot encoding
et.scatter_(0, t.unsqueeze(0), 1)
# expand the binary mask to match the new one-hot encoded target
m = m.unsqueeze(0).expand(et.size())
# prints a vector of ones: every kept position has exactly one active class
print(et[~m].view(10, -1).sum(0))
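To see why the class-first layout works, the same idea can be sketched on the earlier 2x2 toy example. This is only an illustration: t is hard-coded to the example values, num_classes = 2 is an assumption, and the names keep, selected and one_hot_kept are hypothetical. Because the mask is identical for every class, each class keeps the same number of positions, so the flattened result can be reshaped safely.

```python
import torch

# Same toy target as in the first example (values fixed for reproducibility).
t = torch.tensor([[0, 0],
                  [1, 0]])
m = t == 1                        # positions to drop
num_classes = 2                   # assumption: classes are 0 and 1

# Class dimension first: et[c, i, k] == 1 exactly when t[i, k] == c.
et = torch.zeros(num_classes, *t.shape)
et.scatter_(0, t.unsqueeze(0), 1)

# Broadcast the "keep" mask over the class dimension.
keep = (~m).unsqueeze(0).expand_as(et)

# One row per class, one column per kept position.
selected = et[keep].view(num_classes, -1)

# Transpose to get one one-hot vector per kept position.
one_hot_kept = selected.t()
print(one_hot_kept)               # three rows, each equal to [1., 0.]
```

The three kept positions all have class 0, so each one-hot row is [1, 0]; the batch structure itself is still lost, only the per-element encoding is preserved.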
If you still want to keep batch-wise operations, you actually need to treat each batch element separately with a loop, because after masking the batch elements won't necessarily have the same size.
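A minimal sketch of such a per-batch loop might look like this, assuming 10 classes and the same masked value 3 as above. The use of torch.randint and torch.zeros, and the names num_classes and per_sample, are choices made for this sketch rather than part of the original snippet.

```python
import torch

torch.manual_seed(0)              # only to make the sketch reproducible
num_classes = 10                  # assumption: targets lie in [0, 10)

t = torch.randint(0, num_classes, (4, 5, 5))   # target tensor
m = t == 3                                     # binary mask of values to drop

per_sample = []                   # one (n_kept_i, num_classes) tensor per batch element
for t_i, m_i in zip(t, m):
    kept = t_i[~m_i]              # 1-D long tensor of the kept targets
    et_i = torch.zeros(kept.numel(), num_classes)
    et_i.scatter_(1, kept.unsqueeze(1), 1)     # one-hot encode each kept target
    per_sample.append(et_i)

for i, et_i in enumerate(per_sample):
    print(i, tuple(et_i.shape))   # first dimension may differ between batch elements
```

The result is a Python list rather than a single tensor, which is what allows each batch element to keep its own number of rows.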