Speed up forward/backward propagation

Hello, I’m kind of new to PyTorch. I’ve been trying to implement a forward method that takes a tensor, say V, of n values in [0,1], finds the indices of the top-k elements in the tensor, then outputs a tensor of the same shape as V with the top-k positions set to 1 and the rest set to 0. k is a fraction in [0,1] (passed as a string) given to the forward method. My implementation is below. It works, but it seems a bit slow; is there another implementation that could offer a speedup? Here’s the code:

import torch


class Top_k(torch.autograd.Function):
    @staticmethod
    def forward(ctx, V, k):
        # Flatten V, find the indices of the top round(k * n) entries,
        # then build a tensor that is 1 at those indices and 0 elsewhere.
        a = V.view(1, -1).squeeze()
        ind = torch.topk(a, round(float(k) * a.shape[0]))[1]
        a = a * 0
        a[ind] = a[ind] + 1
        a = a.view(V.shape)
        ctx.save_for_backward(V)
        return a

    @staticmethod
    def backward(ctx, grad_output):
        # Pass the incoming gradient through, scaled by the saved values
        # of V; k gets no gradient.
        pvals = ctx.saved_tensors
        return pvals[0] * grad_output, None
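
In case it helps, this is roughly how I’m calling it (the shape and the value of k here are just an example):

V = torch.rand(64, 128, requires_grad=True)  # example input
out = Top_k.apply(V, "0.1")                  # keep the top 10% of entries as 1s
out.sum().backward()                         # gradient is V * grad_output, as defined in backward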

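One variant I was wondering about (just a sketch, I haven’t benchmarked it, and topk_mask is only a name I’m using for this post) builds the mask with torch.zeros_like and scatter_ instead of zeroing a and adding 1 at the indices:

def topk_mask(V, k):
    # Same idea: 1 at the positions of the top round(k * n) values, 0 elsewhere.
    flat = V.reshape(-1)
    ind = torch.topk(flat, round(float(k) * flat.numel()))[1]
    mask = torch.zeros_like(flat)
    mask.scatter_(0, ind, 1.0)
    return mask.view(V.shape)

I’m not sure how much of the time is spent in topk itself versus the indexing, though, so this may not help much.
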
Thanks