Hello, I'm fairly new to PyTorch. I've been trying to implement a forward method that takes a tensor V of n values in [0,1], finds the indices of its top-k elements, and outputs a tensor of the same shape as V with those top-k positions set to 1 and the rest set to 0. k is passed to the forward method as a string representing a fraction in [0,1], so the number of elements kept is round(float(k) * n). My implementation is below. It seems to work but is a bit slow. Is there another implementation that could offer a speedup? Here's the code:
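    import torch

    class Top_k(torch.autograd.Function):
        @staticmethod
        def forward(ctx, V, k):
            # Flatten V to 1-D
            a = V.view(1, -1).squeeze()
            # Indices of the top round(k * n) values
            ind = torch.topk(a, round(float(k) * a.shape[0]))[1]
            # Zero everything, then set the top-k positions to 1
            a = a * 0
            a[ind] = a[ind] + 1
            a = a.view(V.shape)
            ctx.save_for_backward(V)
            return a

        @staticmethod
        def backward(ctx, grad_output):
            pvals = ctx.saved_tensors
            # Gradient w.r.t. V is V * grad_output; k gets no gradient
            return pvals[0] * grad_output, None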
Thanks
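For comparison, here is a minimal sketch of one possible variant of the mask-building step. It is only an idea worth timing, not a confirmed speedup: it allocates the zero tensor directly with `torch.zeros_like` instead of computing `a * 0`, and writes the ones with a plain index assignment. The names `topk_mask` and `num_keep` are made up for this illustration.

```python
import torch

def topk_mask(V, k):
    """Hypothetical helper: 0/1 tensor shaped like V, with ones at the
    positions of the top round(k * n) values of V."""
    flat = V.reshape(-1)                       # 1-D view (or copy) of V
    num_keep = round(float(k) * flat.numel())  # k is a fraction, possibly a string
    idx = torch.topk(flat, num_keep).indices   # indices of the largest values
    mask = torch.zeros_like(flat)              # start from all zeros
    mask[idx] = 1                              # mark the top-k positions
    return mask.view(V.shape)

V = torch.tensor([[0.1, 0.9], [0.5, 0.3]])
print(topk_mask(V, "0.5"))
```

If this turns out to help, the body of `forward` could call such a helper and keep `ctx.save_for_backward(V)` and the `backward` method unchanged.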