# How to use `torch.topk` to select the maximum sum of dim=1 on a 4d tensor without loops

For a 4d tensor of shape `(N, C, H, W)`, I’m trying to select the top-k channels along `dim=1` based on each channel’s sum, and zero out the non-selected channels for each `n` in `N`.

I can easily do this with a nested for-loop:

```
In [39]: x = torch.rand(2, 3, 2, 2)

In [40]: x
Out[40]:
tensor([[[[0.0432, 0.1441],
          [0.4919, 0.4644]],

         [[0.2913, 0.5852],
          [0.6561, 0.0557]],

         [[0.8833, 0.7226],
          [0.4892, 0.5529]]],


        [[[0.2340, 0.2637],
          [0.0494, 0.9076]],

         [[0.3043, 0.2380],
          [0.6766, 0.6793]],

         [[0.7904, 0.2771],
          [0.1928, 0.7959]]]])

In [41]: activation = x.sum(2).sum(2)

In [42]: topk, indices = torch.topk(activation, 2, dim=1)

In [43]: for i, _ in enumerate(indices):
    ...:     for j, _ in enumerate(x[i, :, :, :]):
    ...:         if j not in indices[i]:
    ...:             x[i, j, :, :] = 0

In [44]: x
Out[44]:
tensor([[[[0.0000, 0.0000],
          [0.0000, 0.0000]],

         [[0.2913, 0.5852],
          [0.6561, 0.0557]],

         [[0.8833, 0.7226],
          [0.4892, 0.5529]]],


        [[[0.0000, 0.0000],
          [0.0000, 0.0000]],

         [[0.3043, 0.2380],
          [0.6766, 0.6793]],

         [[0.7904, 0.2771],
          [0.1928, 0.7959]]]])
```

I would like to achieve this without any for-loops. Is it possible to use `torch.gather` for this?

I was able to reduce the complexity to one single for loop using a selection tensor:

```python
activation = x.sum(dim=2).sum(dim=2)
topk, indices = torch.topk(activation, self.k, dim=1)
selection_tensor = torch.zeros_like(x)

for i, _ in enumerate(indices):
    selection_tensor[i, indices[i], :, :] = 1

x = x * selection_tensor
```

This helps a lot, but it would still be nice to throw away the loop altogether.
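For reference, the remaining loop can be dropped entirely with `scatter_`, which writes ones into a `(N, C)` mask at the top-k channel positions in one call. This is a sketch under the assumptions of the example above (`k = 2`, input shape `(2, 3, 2, 2)`); the seed is only there to make it reproducible:

```python
import torch

torch.manual_seed(0)
x = torch.rand(2, 3, 2, 2)
k = 2

# Per-channel activation: sum over H and W -> shape (N, C)
activation = x.sum(dim=(2, 3))
_, indices = torch.topk(activation, k, dim=1)

# (N, C) mask with ones at the top-k channels of each sample, no Python loop
mask = torch.zeros_like(activation)
mask.scatter_(1, indices, 1.0)

# Broadcast the mask over H and W to zero out the non-selected channels
x = x * mask[:, :, None, None]
```

Because `torch.rand` produces strictly positive values, exactly `k` channels per sample remain nonzero afterwards.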

Instead of the for loop to create your `selection_tensor`, you could also use advanced indexing (note it is `selection_tensor.size(0)`, i.e. the batch size `N`):

```python
selection_tensor[torch.arange(selection_tensor.size(0)), indices.t()] = 1
```
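Put into a self-contained snippet (a sketch following the shapes of the example above, with `k = 2` and a fixed seed for reproducibility), the indexing trick looks like this:

```python
import torch

torch.manual_seed(0)
x = torch.rand(2, 3, 2, 2)
k = 2

activation = x.sum(dim=(2, 3))
_, indices = torch.topk(activation, k, dim=1)

selection_tensor = torch.zeros_like(x)
# torch.arange(N) broadcasts against the transposed (k, N) index tensor,
# so each (batch, channel) pair selects one full (H, W) slice to set to 1
selection_tensor[torch.arange(x.size(0)), indices.t()] = 1
x = x * selection_tensor
```

The broadcast produces the pairs `(n, indices[n, i])` for every sample `n` and every one of its `k` top channels, which is exactly what the loop did one row at a time.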