How to use `torch.topk` to select the maximum sum of dim=1 on a 4d tensor without loops

For a 4d tensor of shape (N, C, H, W), I’m trying to select top k channels from dim=1 based on the sum of each channel, and zero out the non-selected channels for each n in N.

I can easily do this with a nested for-loop:

In [39]: x = torch.rand(2, 3, 2, 2)
In [40]: x
Out[40]: 
tensor([[[[0.0432, 0.1441],
          [0.4919, 0.4644]],

         [[0.2913, 0.5852],
          [0.6561, 0.0557]],

         [[0.8833, 0.7226],
          [0.4892, 0.5529]]],


        [[[0.2340, 0.2637],
          [0.0494, 0.9076]],

         [[0.3043, 0.2380],
          [0.6766, 0.6793]],

         [[0.7904, 0.2771],
          [0.1928, 0.7959]]]])

In [41]: activation = x.sum(2).sum(2)

In [42]: topk, indices = torch.topk(activation, 2, dim=1)

In [43]: for i, _ in enumerate(indices):
    ...:     for j, _ in enumerate(x[i, :, :, :]):
    ...:         if j not in indices[i]:
    ...:             x[i, j, :, :] = 0

In [44]: x
Out[44]: 
tensor([[[[0.0000, 0.0000],
          [0.0000, 0.0000]],

         [[0.2913, 0.5852],
          [0.6561, 0.0557]],

         [[0.8833, 0.7226],
          [0.4892, 0.5529]]],


        [[[0.0000, 0.0000],
          [0.0000, 0.0000]],

         [[0.3043, 0.2380],
          [0.6766, 0.6793]],

         [[0.7904, 0.2771],
          [0.1928, 0.7959]]]])
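The two chained sums in the session above reduce H and W one at a time. `Tensor.sum` also accepts a tuple of dims, so the per-channel activation can be computed in one call. Here is a small sketch that checks the two forms agree, assuming random input of the same shape as in the question:

```python
import torch

torch.manual_seed(0)
x = torch.rand(2, 3, 2, 2)  # (N, C, H, W), random values for illustration

# Reduce H and W one at a time, as in the session above
activation_chained = x.sum(2).sum(2)     # (N, C)
# Reduce both spatial dims in a single call by passing a tuple of dims
activation = x.sum(dim=(2, 3))           # (N, C)

print(torch.allclose(activation_chained, activation))  # True
topk, indices = torch.topk(activation, 2, dim=1)       # values and indices, both (N, 2)
```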

I would like to achieve this without any for-loops. Is it possible to use torch.gather for this?

I was able to reduce this to a single for loop by using a selection tensor:

activation = x.sum(dim=2).sum(dim=2)
topk, indices = torch.topk(activation, k, dim=1)  # k: number of channels to keep (self.k inside a module)
selection_tensor = torch.zeros_like(x)

for i, _ in enumerate(indices):
    selection_tensor[i, indices[i], :, :] = 1

x = x * selection_tensor

This helps a lot, but it would still be nice to get rid of the loop altogether.
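On the `torch.gather` question: `gather` reads values at given indices, so it does not build the mask by itself. Its writing counterpart, `Tensor.scatter_`, can turn the `topk` indices into a 0/1 mask with no Python loop. The following is a minimal sketch of that swapped-in technique (scatter rather than gather); the helper name `keep_topk_channels` is hypothetical. It builds an (N, C) mask and broadcasts it over H and W instead of allocating a mask the full size of `x`:

```python
import torch


def keep_topk_channels(x: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k channels (dim=1) with the largest spatial sum in each sample; zero the rest."""
    activation = x.sum(dim=(2, 3))                  # (N, C) per-channel sums
    _, indices = torch.topk(activation, k, dim=1)   # (N, k) selected channel ids
    # scatter_ writes 1.0 into each row of the (N, C) mask at the selected channel ids
    mask = torch.zeros_like(activation).scatter_(1, indices, 1.0)
    # Add two trailing singleton dims so the mask broadcasts over H and W
    return x * mask[:, :, None, None]


torch.manual_seed(0)
x = torch.rand(2, 3, 2, 2)  # (N, C, H, W)
out = keep_topk_channels(x, k=2)
print(out)
```

Each sample keeps exactly `k` channels unchanged (multiplied by 1.0), and every other channel becomes zero.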

Instead of the for loop to create your selection_tensor, you could also use indexing:

selection_tensor[torch.arange(selection_tensor.size(0)), indices.t()] = 1
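A self-contained sketch of the indexing approach follows, assuming random input and `k = 2` as in the question. The row index is built from `selection_tensor.size(0)`, i.e. the batch size N:

```python
import torch

torch.manual_seed(0)
k = 2
x = torch.rand(2, 3, 2, 2)  # (N, C, H, W)

activation = x.sum(dim=2).sum(dim=2)              # (N, C)
topk, indices = torch.topk(activation, k, dim=1)  # indices: (N, k)
selection_tensor = torch.zeros_like(x)

# torch.arange(N) has shape (N,) and indices.t() has shape (k, N); they broadcast
# to (k, N), pairing sample n with each of its k selected channels.
# Only dims 0 and 1 are indexed, so whole (H, W) slices are set to 1.
selection_tensor[torch.arange(selection_tensor.size(0)), indices.t()] = 1

x = x * selection_tensor
print(x)
```

An equivalent way to write the same index pairs is `selection_tensor[torch.arange(N).unsqueeze(1), indices] = 1`, where an (N, 1) row index broadcasts against the (N, k) `indices` directly, without the transpose.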