Masking Batch-Wise Max!

One solution, inspired by the thread "How to efficiently normalize a batch of tensor to [0, 1]", is as follows.

import torch
batch_size, height, width = 2, 2, 2
torch.manual_seed(0)
loss = torch.randn((batch_size, height, width))
print(loss)

# loss.size(0) is the same as batch_size here
loss = loss.view(loss.size(0), -1)  # tensor of shape (batch_size, height * width)
bound = loss.max(dim=1, keepdim=True)[0] * 0.9  # per-sample bound, shape (batch_size, 1)
mask = torch.zeros_like(loss)
mask[loss < bound] = 1.0   # set the values I want to keep to 1
mask[loss >= bound] = 0.0  # already zero from zeros_like, kept for symmetry with your code
loss *= mask
loss = loss.view(batch_size, height, width)  # back to shape (batch_size, height, width)

print(loss)
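As a side note, if your PyTorch version has torch.amax (1.7 or later, if I remember correctly), you can skip the reshape entirely by reducing over both spatial dims at once. A minimal sketch of the same masking, under that assumption:

import torch

torch.manual_seed(0)
batch_size, height, width = 2, 2, 2
loss = torch.randn((batch_size, height, width))

# Per-sample max over the spatial dims, kept as shape (batch_size, 1, 1)
# so it broadcasts directly against the (batch_size, height, width) loss.
bound = torch.amax(loss, dim=(1, 2), keepdim=True) * 0.9

mask = (loss < bound).float()  # 1 where we keep the value, 0 where we drop it
loss = loss * mask

print(loss)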

I’ve set the random seed just to ensure reproducibility; the following code (yours) gives the same result:

torch.manual_seed(0)
loss = torch.randn((batch_size, height, width))
print(loss)

for i in range(batch_size):
    bound = torch.max(loss[i]) * 0.9
    mask = torch.zeros_like(loss[i])
    mask[loss[i] < bound] = 1.0  # set the values I want to keep to 1
    mask[loss[i] >= bound] = 0.0
    loss[i] *= mask

print(loss)
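One caveat: both versions modify loss in place (loss *= mask / loss[i] *= mask), which can raise an error if loss requires grad and is needed for the backward pass. An out-of-place sketch using torch.where (again assuming torch.amax is available) would look like this:

import torch

torch.manual_seed(0)
batch_size, height, width = 2, 2, 2
loss = torch.randn((batch_size, height, width))

bound = torch.amax(loss, dim=(1, 2), keepdim=True) * 0.9

# Build a new tensor instead of writing into `loss`,
# so the original tensor (and its autograd history) is left untouched.
masked_loss = torch.where(loss < bound, loss, torch.zeros_like(loss))

print(masked_loss)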