How to use torch.max() as the condition for torch.where()?

I’m trying to figure out how to use torch.max() as the condition for torch.where(). My goal is to zero out the non-max values of a 4D tensor across the channels of each sample, so that at every spatial position only the largest value along dim 1 is kept. For example:

# given:
a = torch.arange(3*2*2).view(1, 3, 2, 2)
>>>tensor([[[[ 0,  1],
          [ 2,  3]],

         [[ 4,  5],
          [ 6,  7]],

         [[ 8,  9],
          [10, 11]]]])


# wanted:
tensor([[[[ 0,  0],
          [ 0,  0]],

         [[ 0,  0],
          [ 0,  0]],

         [[ 8,  9],
          [10, 11]]]])

I’m trying to use some form of torch.where(condition=a.max(dim=1), a, torch.zeros_like(a)), but I can’t quite figure out how to get it to work.

Does anyone know a nice approach to do this?
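One way the torch.where route could work, sketched below: a.max(dim=1) returns a (values, indices) namedtuple rather than a boolean tensor, so it cannot be passed as the condition directly. Taking .values with keepdim=True gives shape (N, 1, H, W), which broadcasts against a, and comparing the two produces a boolean mask. This is a minimal sketch; note that if several channels tie for the maximum at a position, all of them are kept.

```python
import torch

a = torch.arange(3 * 2 * 2).view(1, 3, 2, 2)

# per-position maximum over the channel dim; keepdim=True keeps shape (N, 1, H, W)
max_vals = a.max(dim=1, keepdim=True).values

# boolean mask of shape (N, C, H, W): True where a channel holds the maximum
mask = a == max_vals

# keep the max entries, zero out everything else
z = torch.where(mask, a, torch.zeros_like(a))
print(z)
```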

This code should work:

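import torch

a = torch.arange(3*2*2).view(1, 3, 2, 2)
# max over the channel dim; keepdim=True keeps val/idx at shape (N, 1, H, W)
val, idx = a.max(1, keepdim=True)
# write each max value back into its own argmax channel, leaving zeros elsewhere
z = torch.zeros_like(a).scatter(1, idx, val)
print(z)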
> tensor([[[[ 0,  0],
            [ 0,  0]],

           [[ 0,  0],
            [ 0,  0]],

           [[ 8,  9],
            [10, 11]]]])
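In the example above every position has its maximum in channel 2, so it does not exercise positions whose maximum lies in different channels. The check below uses a hypothetical input built with torch.randperm (distinct values, so no ties) and compares a scatter-based version against a torch.where mask version; with no ties the two should agree, and each (sample, position) should keep exactly one value.

```python
import torch

torch.manual_seed(0)
# hypothetical input: a random permutation guarantees distinct values (no ties),
# and the per-position argmax will generally differ across positions and samples
x = torch.randperm(2 * 3 * 4 * 4).view(2, 3, 4, 4)

val, idx = x.max(dim=1, keepdim=True)  # both (N, 1, H, W)

# scatter: place each max value at its argmax channel, zeros elsewhere
z_scatter = torch.zeros_like(x).scatter(1, idx, val)

# torch.where with a boolean mask built from the max values
z_where = torch.where(x == val, x, torch.zeros_like(x))

print(torch.equal(z_scatter, z_where))
```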