K-winner-take-all advanced indexing

I’m trying to implement an efficient k-winner-take-all (kWTA) activation function in PyTorch.

import math

import torch


class KWinnersTakeAll(torch.autograd.Function):

    @staticmethod
    def forward(ctx, tensor, sparsity: float):
        batch_size, embedding_size = tensor.shape
        # sort each sample's activations from largest to smallest
        _, argsort = tensor.sort(dim=1, descending=True)
        k_active = math.ceil(sparsity * embedding_size)
        active_indices = argsort[:, :k_active]
        mask_active = torch.zeros_like(tensor, dtype=torch.bool)
        for sample_id in range(batch_size):
            mask_active[sample_id, active_indices[sample_id]] = True
        # binarize: the k winners become 1, everything else 0
        tensor[~mask_active] = 0
        tensor[mask_active] = 1
        ctx.save_for_backward(mask_active)
        return tensor
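
For reference, the function is called through the usual apply method of torch.autograd.Function; with sparsity 0.5 on a (2, 4) input, the forward binarizes each row like this (the values are just an illustration):

x = torch.tensor([[0.3, -1.2, 0.8, 0.1],
                  [2.0,  0.5, -0.4, 1.1]])
y = KWinnersTakeAll.apply(x, 0.5)
# k_active = ceil(0.5 * 4) = 2 winners per row:
# y == tensor([[1., 0., 1., 0.],
#              [1., 0., 0., 1.]])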

What I want is to get rid of the for-loop by vectorizing the computation it does. Any idea how to do that?
With mask_active.shape == (256, 128) and active_indices.shape == (256, 7), I get mask_active[active_indices].shape == (256, 7, 128), but I’d like the indexing to produce shape (256, 7) so that I could simply assign mask_active[active_indices] = 1.

I think you could use advanced indexing for your use case.
Here is a small example with reduced sizes to make it easier to follow:

batch_size = 5
features = 4

# two "active" indices per sample, drawn at random
idx = torch.empty(batch_size, 2, dtype=torch.long).random_(features)
mask_active = torch.zeros(batch_size, features, dtype=torch.bool)
mask_active_fast = torch.zeros(batch_size, features, dtype=torch.bool)

# Your approach
for i in range(batch_size):
    mask_active[i, idx[i]] = 1

# Advanced indexing: row indices of shape (batch_size, 1) broadcast against the column indices
mask_active_fast[torch.arange(batch_size).unsqueeze(1), idx] = 1

# Check
(mask_active == mask_active_fast).all()
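
Putting that indexing back into your forward, a vectorized version could look roughly like this (a sketch; it returns the binary mask cast to the input dtype, which is equivalent to the two in-place assignments in your code but leaves the input untouched):

import math

import torch


class KWinnersTakeAll(torch.autograd.Function):

    @staticmethod
    def forward(ctx, tensor, sparsity: float):
        batch_size, embedding_size = tensor.shape
        _, argsort = tensor.sort(dim=1, descending=True)
        k_active = math.ceil(sparsity * embedding_size)
        active_indices = argsort[:, :k_active]
        mask_active = torch.zeros_like(tensor, dtype=torch.bool)
        # row indices (batch_size, 1) broadcast against (batch_size, k_active) column indices
        rows = torch.arange(batch_size, device=tensor.device).unsqueeze(1)
        mask_active[rows, active_indices] = True
        ctx.save_for_backward(mask_active)
        # equivalent to tensor[~mask] = 0; tensor[mask] = 1
        return mask_active.to(tensor.dtype)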

Sorry for bringing this up again, but does anybody have an idea how to implement the backward pass?

Or is it simply a matter of passing the gradient through where the mask is one, like this:


    @staticmethod
    def backward(ctx, grad_output):
        mask_active, = ctx.saved_tensors
        # pass the gradient through only where the unit was a winner
        grad_input = grad_output.clone()
        grad_input[~mask_active] = 0
        # no gradient for the sparsity argument
        return grad_input, None
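
If the vectorized forward sketched above and this backward are combined into one Function, a quick sanity check could look like this (sparsity 0.25 is an arbitrary value):

x = torch.randn(4, 8, requires_grad=True)
y = KWinnersTakeAll.apply(x, 0.25)
y.sum().backward()
# the gradient is 1 at the ceil(0.25 * 8) = 2 winning positions per row, 0 elsewhere
print(x.grad)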

@adamk zeroing out masked values the way you do won’t work for complex data (in my experiments, only MNIST tolerates such a harsh cut-off). However, the paper linked below did use this approach and showed that it works.

For further details on how to propagate the gradients of the k-winners-take-all function backward, refer to my kwta.py script below. The solution is, as always, to approximate the non-differentiable function with a soft approximation (see the KWinnersTakeAllSoft class).
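
For readers without the script at hand, the rough idea is to replace the hard 0/1 mask with a steep sigmoid around the k-th largest activation, so that every unit receives some gradient. A minimal sketch, assuming a hardness parameter that controls the steepness (the names and the exact thresholding are my illustration, not necessarily what kwta.py does):

import math

import torch
import torch.nn as nn


class KWinnersTakeAllSoft(nn.Module):
    # Soft relaxation of kWTA: a steep sigmoid around the k-th largest value per sample.

    def __init__(self, sparsity=0.05, hardness=10.0):
        super().__init__()
        self.sparsity = sparsity
        self.hardness = hardness

    def forward(self, x):
        batch_size, embedding_size = x.shape
        k_active = math.ceil(self.sparsity * embedding_size)
        # the k-th largest activation of each sample serves as a soft threshold
        threshold, _ = x.topk(k_active, dim=1)
        threshold = threshold[:, -1:]  # shape (batch_size, 1)
        # as hardness grows, this approaches the hard 0/1 mask
        return torch.sigmoid(self.hardness * (x - threshold))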

For convenience, I made a list of how people implement kWTA in PyTorch and deal with its gradients.
