Masking Batch-Wise Max!

Given a batch of error maps, I would like to mask out error values based on the max of each error map in the batch.

Calling torch.max on the whole tensor returns a single maximum value for the entire batch of error maps. Masking with this single value gives masks that are inconsistent from one error map to the next.
Example Input: Loss Map shape: (8, 256, 640) (Batch, Height, Width)
Return: (8, 256, 640) with masked out pixels.
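
To make the inconsistency concrete, here is a small sketch on a toy tensor. The values and the variable names (`global_bound`, `per_sample_bound`, `kept_global`, `kept_per_sample`) are made up for illustration and are not part of the original question.

```python
import torch

# Two tiny "error maps" with very different scales, shape (2, 1, 2)
loss = torch.tensor([[[1.0, 2.0]], [[10.0, 20.0]]])

# A single threshold for the whole batch
global_bound = torch.max(loss) * 0.9
# One threshold per error map, shape (2,)
per_sample_bound = loss.flatten(1).max(dim=1).values * 0.9

# With the global bound, nothing in the first map is masked, not even its own max
kept_global = loss < global_bound
# With per-map bounds, each map loses its own largest values
kept_per_sample = loss < per_sample_bound.view(-1, 1, 1)
```

With the global threshold, the first map is left completely untouched because all of its values are far below the batch-wide maximum.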

Current Implementation:

bound = torch.max(loss[i]) * 0.9  # find the threshold
mask = torch.zeros_like(loss[i])  # create the mask
mask[loss[i] < bound] = 1.0       # set the values I want to keep to 1
mask[loss[i] >= bound] = 0.0      # set the others to 0
loss[i] *= mask                   # mask!

I am looking for an efficient way to do this in PyTorch instead of a naive approach of using a for loop for each image in a batch.


One solution, inspired by the thread "How to efficiently normalize a batch of tensor to [0, 1]", is as follows.

import torch
batch_size, height, width = 2, 2, 2
torch.manual_seed(0)  # fix the seed so the result is reproducible
loss = torch.randn((batch_size, height, width))

# loss.size(0) or batch_size
loss = loss.view(loss.size(0), -1) # tensor of shape (batch_size, height * width)
bound = loss.max(dim = 1, keepdim=True)[0] * 0.9
mask = torch.zeros_like(loss)
mask[loss < bound] = 1.0   # set the values i want to keep to 1
mask[loss >= bound] = 0.0 
loss *= mask
loss = loss.view(batch_size, height, width) # tensor of shape (batch_size, height, width)
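
The same idea can be written without reshaping or in-place edits. The following is only a sketch: the function name `mask_below_batchwise_max` and the `ratio` parameter are hypothetical, and it assumes a PyTorch version that provides `Tensor.amax` and an input with at least two dimensions.

```python
import torch

def mask_below_batchwise_max(loss: torch.Tensor, ratio: float = 0.9) -> torch.Tensor:
    """Zero out values >= ratio * (per-sample max). Hypothetical helper name."""
    # Reduce over every axis except the batch axis; keepdim lets the bound broadcast
    reduce_dims = tuple(range(1, loss.dim()))
    bound = loss.amax(dim=reduce_dims, keepdim=True) * ratio
    # Keep values below the per-sample bound, replace the rest with zeros
    return torch.where(loss < bound, loss, torch.zeros_like(loss))

torch.manual_seed(0)
example = torch.randn(8, 256, 640)          # (Batch, Height, Width)
masked = mask_below_batchwise_max(example)  # same shape, input left unchanged
```

Returning a new tensor rather than modifying `loss` in place is a design choice here; it avoids surprises if the unmasked loss is still needed elsewhere.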


I set the random seed to ensure reproducibility. With the same seed, the following code (yours) gives the same result.

torch.manual_seed(0)  # reseed before each version so both start from the same tensor
loss = torch.randn((batch_size, height, width))

for i in range(batch_size) :
    bound = torch.max(loss[i]) * 0.9
    mask = torch.zeros_like(loss[i])
    mask[loss[i] < bound] = 1.0  # set the values i want to keep to 1
    mask[loss[i] >= bound] = 0.0 
    loss[i] *= mask 
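
One way to check the equivalence directly is to run both versions on copies of the same tensor and compare the results. This is a minimal sketch; the sizes and the names `vectorized`, `looped`, and `same` are chosen for illustration only.

```python
import torch

torch.manual_seed(0)
batch_size, height, width = 8, 4, 6
loss = torch.randn(batch_size, height, width)

# Vectorized version, applied to a flattened copy
flat = loss.clone().view(batch_size, -1)
bound = flat.max(dim=1, keepdim=True)[0] * 0.9
vectorized = (flat * (flat < bound).to(flat.dtype)).view(batch_size, height, width)

# Loop version, applied to another copy
looped = loss.clone()
for i in range(batch_size):
    sample_bound = torch.max(looped[i]) * 0.9
    looped[i] *= (looped[i] < sample_bound).to(looped.dtype)

# True when both approaches produce identical tensors
same = torch.equal(vectorized, looped)
```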


In case it’s useful for future readers, here’s a generalized function inspired by the solution from @pascal_notsawo. In my use case, I needed the 3D indices of each sample’s maximum value (as opposed to masking values in the original tensor, as the OP requested). I also handled the edge case of multiple identical maximum values.

from collections import defaultdict
import torch

def sample_max_idxs(batch, on_duplicate='first'):
    """Get indices of maximum sample values

    A combination of operations inspired by earlier forum posts.

    :param torch.Tensor[N, ...] batch: Batch of ND tensor samples
    :param str on_duplicate: how to handle an exact tie in maximum values. Valid options are
        * first: keep the lowest index (row-major order)
        * last: keep the highest index (row-major order)
        * None: don't remove duplicates
    :return torch.Tensor[N, len(batch.shape)] idxs: Coordinates for the max value within
        each sample of the batch
    """

    # Compress all secondary axes
    batch_size, *other_axes = batch.shape
    batch = batch.view(batch_size, -1)

    # Create a 1D mask for the max indices
    max_per_sample = batch.max(dim=1, keepdim=True)[0]
    mask = torch.zeros_like(batch)
    mask[batch == max_per_sample] = 1

    # Reshape to original dims and use ``nonzero`` binary trick to generate ND indices
    mask = mask.view(batch_size, *other_axes)
    idxs = torch.nonzero(mask * mask)

    # Handle duplicates
    batch_idxs = torch.sort(idxs[:, 0])[0]
    if on_duplicate is None or batch_idxs.numel() == batch_size:
        return idxs
    valid_duplicates = ['first', 'last']
    if on_duplicate not in valid_duplicates:
        raise ValueError(f'Expected on_duplicate to be one of {",".join(valid_duplicates)}, got {on_duplicate}')
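    # Group the row positions of ``idxs`` by the sample (batch) index they belong to.
    # ``torch.nonzero`` returns rows in lexicographic order, so rows are already sorted by sample.
    batch2locs = defaultdict(list)
    for loc, batch_idx in enumerate(batch_idxs.tolist()):
        batch2locs[batch_idx].append(loc)
    # For each sample with tied maxima, keep one row and mark the rest for removal
    remove = set()
    for locs in batch2locs.values():
        if len(locs) > 1:
            keep_idx = 0 if on_duplicate == 'first' else len(locs) - 1
            remove.update(loc for i, loc in enumerate(locs) if i != keep_idx)
    keep = [i for i in range(idxs.size(0)) if i not in remove]
    return idxs[keep]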
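
If only the 'first' behaviour is needed, a shorter route is to take `argmax` over the flattened samples and convert the flat index back to ND coordinates. This is a sketch under assumptions: `sample_argmax_idxs` is a hypothetical name, and it relies on `argmax` returning the first maximal index on ties, which recent PyTorch documentation states but which is worth verifying for your version and device.

```python
import torch

def sample_argmax_idxs(batch: torch.Tensor) -> torch.Tensor:
    """Return an (N, batch.dim()) tensor of [sample, i, j, ...] coordinates
    of each sample's maximum. Hypothetical helper, not from the original post."""
    batch_size, *other_axes = batch.shape
    # Flat (row-major) index of the maximum within each sample, shape (N,)
    flat_idx = batch.reshape(batch_size, -1).argmax(dim=1)
    # Unravel the flat index into ND coordinates, starting from the last axis
    coords = []
    for size in reversed(other_axes):
        coords.append(flat_idx % size)
        flat_idx = torch.div(flat_idx, size, rounding_mode='floor')
    sample_idx = torch.arange(batch_size, device=batch.device)
    # Columns: sample index followed by the coordinates in original axis order
    return torch.stack([sample_idx, *reversed(coords)], dim=1)

batch = torch.zeros(2, 3, 4)
batch[0, 1, 2] = 5.0
batch[1, 2, 3] = 7.0
idxs = sample_argmax_idxs(batch)
```

Here `idxs` has one row per sample, in the same `[sample, i, j]` layout that `torch.nonzero` produces in the function above.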