For a 4d tensor of shape `(N, C, H, W)`, I'm trying to select the top-k channels along `dim=1` based on the sum of each channel, and zero out the non-selected channels for each `n` in `N`.

I can easily do this with a nested for-loop:

```
In [39]: x = torch.rand(2, 3, 2, 2)

In [40]: x
Out[40]:
tensor([[[[0.0432, 0.1441],
          [0.4919, 0.4644]],

         [[0.2913, 0.5852],
          [0.6561, 0.0557]],

         [[0.8833, 0.7226],
          [0.4892, 0.5529]]],


        [[[0.2340, 0.2637],
          [0.0494, 0.9076]],

         [[0.3043, 0.2380],
          [0.6766, 0.6793]],

         [[0.7904, 0.2771],
          [0.1928, 0.7959]]]])

In [41]: activation = x.sum(2).sum(2)

In [42]: topk, indices = torch.topk(activation, 2, dim=1)

In [43]: for i, _ in enumerate(indices):
    ...:     for j, _ in enumerate(x[i, :, :, :]):
    ...:         if j not in indices[i]:
    ...:             x[i, j, :, :] = 0

In [44]: x
Out[44]:
tensor([[[[0.0000, 0.0000],
          [0.0000, 0.0000]],

         [[0.2913, 0.5852],
          [0.6561, 0.0557]],

         [[0.8833, 0.7226],
          [0.4892, 0.5529]]],


        [[[0.0000, 0.0000],
          [0.0000, 0.0000]],

         [[0.3043, 0.2380],
          [0.6766, 0.6793]],

         [[0.7904, 0.2771],
          [0.1928, 0.7959]]]])
```

I would like to achieve this without any for-loops. Is it possible to use `torch.gather` for this?
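For what it's worth, one vectorized approach that seems to work is to go the other way around: instead of gathering the selected channels, scatter `True` into a boolean `(N, C)` mask at the top-k indices and multiply. This produces a new tensor rather than modifying `x` in place (an assumption on my part that that's acceptable):

```python
import torch

torch.manual_seed(0)
x = torch.rand(2, 3, 2, 2)

# channel activations: sum over H and W -> shape (N, C)
activation = x.sum(2).sum(2)
topk, indices = torch.topk(activation, 2, dim=1)

# (N, C) boolean mask, True for the selected channels
mask = torch.zeros_like(activation, dtype=torch.bool)
mask.scatter_(1, indices, True)

# broadcast the mask over H and W to zero out non-selected channels
out = x * mask[:, :, None, None]
```

The `mask[:, :, None, None]` reshapes the mask to `(N, C, 1, 1)` so it broadcasts against `(N, C, H, W)`; multiplying a float tensor by a bool mask zeroes the masked-out entries.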