For a 4D tensor of shape (N, C, H, W), I'm trying to select the top-k channels along dim=1 based on the sum of each channel, and zero out the non-selected channels for each n in N.
I can easily do this with a nested for-loop:
In [39]: x = torch.rand(2, 3, 2, 2)
In [40]: x
Out[40]:
tensor([[[[0.0432, 0.1441],
          [0.4919, 0.4644]],
         [[0.2913, 0.5852],
          [0.6561, 0.0557]],
         [[0.8833, 0.7226],
          [0.4892, 0.5529]]],
        [[[0.2340, 0.2637],
          [0.0494, 0.9076]],
         [[0.3043, 0.2380],
          [0.6766, 0.6793]],
         [[0.7904, 0.2771],
          [0.1928, 0.7959]]]])
In [41]: activation = x.sum(2).sum(2)
In [42]: topk, indices = torch.topk(activation, 2, dim=1)
In [43]: for i, _ in enumerate(indices):
    ...:     for j, _ in enumerate(x[i, :, :, :]):
    ...:         if j not in indices[i]:
    ...:             x[i, j, :, :] = 0
In [44]: x
Out[44]:
tensor([[[[0.0000, 0.0000],
          [0.0000, 0.0000]],
         [[0.2913, 0.5852],
          [0.6561, 0.0557]],
         [[0.8833, 0.7226],
          [0.4892, 0.5529]]],
        [[[0.0000, 0.0000],
          [0.0000, 0.0000]],
         [[0.3043, 0.2380],
          [0.6766, 0.6793]],
         [[0.7904, 0.2771],
          [0.1928, 0.7959]]]])
I would like to achieve this without any for-loops. Is it possible to use torch.gather for this?
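
For reference, this is the kind of loop-free sketch I have in mind; it builds a boolean channel mask with scatter_ instead of using gather directly, so I'm not sure it is the idiomatic way (k and mask here are just illustrative names, not part of the session above):

import torch

x = torch.rand(2, 3, 2, 2)                     # (N, C, H, W)
k = 2

activation = x.sum(dim=(2, 3))                 # per-channel sums, shape (N, C)
_, indices = torch.topk(activation, k, dim=1)  # top-k channel indices, shape (N, k)

# Mark the selected channels in an (N, C) boolean mask,
# then broadcast the mask over H and W to zero out the rest.
mask = torch.zeros_like(activation, dtype=torch.bool)
mask.scatter_(1, indices, True)
x = x * mask[:, :, None, None]

This keeps the same topk call as above; whether gather itself can express the final masking step more directly is exactly what I'm asking.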