Top-k based on weight matrix destroys gradient

I have a tensor of features feat of shape [n_shot, n_way, n_dim] and a weight matrix weight of shape [n_shot, n_way]. I want to take the top-k features along the first dimension, according to the weight matrix. Since torch.topk doesn't support this natively, I've implemented it in the following way:

import torch
from torch.autograd import Variable

# Indices of the k highest-weighted samples, per class.
topk_vals, ind = torch.topk(weight, k, dim=0)   # ind: [k, n_way]
feat_topk = Variable(torch.zeros(k, n_way, n_dim))
for way in range(n_way):
    for shot in range(k):
        # Copy the feature vector of the shot-th highest-weighted sample.
        feat_topk[shot][way] = feat[ind[shot][way]][way]

This works, but the in-place assignment obviously means that feat_topk carries no gradient information, which I need, since feat_topk is used directly to compute the loss. I know torch.topk itself is backprop-able, but I'm not sure how to write a version of this selection that is. I would appreciate any help on this, thank you!
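
In case it helps anyone answering: below is a gather-based sketch I've been considering (the ind_exp name is just mine). As far as I know, torch.gather is differentiable with respect to its input tensor, so something like this might avoid the in-place writes entirely, though I haven't verified it behaves the same as my loop:

# Hypothetical gather-based version: expand the top-k indices so that
# each index selects a full feature vector along dim 0.
ind_exp = ind.unsqueeze(-1).expand(k, n_way, n_dim)  # [k, n_way, n_dim]
feat_topk = torch.gather(feat, 0, ind_exp)           # [k, n_way, n_dim]

If this is equivalent, I'd expect the gradient to flow back into feat at the gathered positions (and be zero elsewhere); the indices themselves are discrete, so no gradient would flow through the top-k selection itself, which I assume is fine.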