What I want is to get rid of the for-loop by vectorizing the computations it does. Any ideas how to do that?
Let mask_active.shape == (256, 128) and active_indices.shape == (256, 7). Then mask_active[active_indices].shape == (256, 7, 128), but I’d like the indexing to produce shape (256, 7) so that I could assign mask_active[active_indices] = 1
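One way to get the desired per-row assignment is to pair each row index with its column indices via broadcasting. A minimal NumPy sketch with the shapes from the post (the same advanced-indexing semantics apply to PyTorch tensors):

```python
import numpy as np

batch, n_neurons, k = 256, 128, 7
mask_active = np.zeros((batch, n_neurons), dtype=np.int64)
# Dummy indices standing in for the real active_indices of shape (256, 7).
active_indices = np.random.randint(0, n_neurons, size=(batch, k))

# rows has shape (batch, 1); it broadcasts against (batch, k), so each
# row i is paired with its own k column indices.
rows = np.arange(batch)[:, None]
mask_active[rows, active_indices] = 1
```

In PyTorch the equivalent one-liner is `mask_active.scatter_(1, active_indices, 1)`, which sets the indexed positions along dim 1 without any loop.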

@adamk zeroing out masked values, as you do, didn’t work for complex data in my experiments (only MNIST tolerates such a harsh cut-off). However, the paper linked below did use this approach and showed that it works.

For details on how to propagate gradients backward through the k-winners-take-all function, refer to my kwta.py script below. The solution is, as always, to approximate the non-differentiable function with a soft one (see the KWinnersTakeAllSoft class).
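To illustrate the idea, here is a hypothetical sketch of what such a soft approximation could look like (the class name follows the post; the implementation below, using a temperature-scaled sigmoid around the k-th largest activation, is my assumption, not the author's actual kwta.py):

```python
import torch
import torch.nn as nn

class KWinnersTakeAllSoft(nn.Module):
    """Differentiable stand-in for hard kWTA: instead of a 0/1 mask of the
    top-k activations, emit a sigmoid centered at the k-th largest value."""

    def __init__(self, k: int, hardness: float = 10.0):
        super().__init__()
        self.k = k
        self.hardness = hardness  # larger -> closer to the hard step function

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # k-th largest activation per row, shape (batch, 1)
        threshold = x.topk(self.k, dim=1).values[:, -1:]
        # Smoothly gates activations above/below the threshold; gradients
        # flow through both x and the threshold.
        return torch.sigmoid(self.hardness * (x - threshold))

x = torch.randn(4, 16, requires_grad=True)
y = KWinnersTakeAllSoft(k=3)(x)
y.sum().backward()  # no error: the soft mask is differentiable
```

As `hardness` grows, the output approaches the hard 0/1 kWTA mask while still providing usable gradients during training.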

For convenience, I made a list of how people implement kWTA in PyTorch and deal with its gradients.

- my kwta.py script: straightforward hard kWTA (KWinnersTakeAll) and its soft approximation (KWinnersTakeAllSoft).

- two implementations of the “How Can We Be So Dense?” paper (Subutai Ahmad, Luiz Scheinkman, 2019). Both employ the straightforward hard kWTA.