Use argmax to get values from another tensor


I am currently trying to implement a network in which a set of convolution weights is predicted. Link to paper

First, part of the network (the category branch) predicts a class and its probability score for each cell of a grid that covers the image. Based on these scores, the network decides which cells' predicted weights (from the kernel branch) are used in the final output.

For example, say the category branch uses a 10x10 grid with 8 possible classes. For each cell in that grid, the kernel branch predicts the weights of a 1x1 convolution that is applied to a feature map with a depth of 256. I am assuming a batch size of 4.
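A 1x1 convolution whose weights come from a single grid cell is just a dot product between that cell's 256-dim kernel vector and the 256-channel feature vector at every pixel. The sketch below uses small, made-up spatial sizes so it runs quickly on CPU, and checks that `F.conv2d` and `torch.einsum` give the same result; this equivalence is what makes a loop-free version possible.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

C, H, W = 256, 16, 32              # depth as in the example; spatial size shrunk for speed
feature = torch.rand(1, C, H, W)   # one image's feature map
kernels = torch.rand(5, C)         # 5 predicted 1x1 kernels (one per selected cell)

# Standard route: reshape kernels to (out_channels, in_channels, 1, 1)
out_conv = F.conv2d(feature, kernels.view(5, C, 1, 1))

# Equivalent route: weighted sum over channels at every pixel
out_einsum = torch.einsum("kc,bchw->bkhw", kernels, feature)

print(out_conv.shape)  # torch.Size([1, 5, 16, 32])
print(torch.allclose(out_conv, out_einsum, rtol=1e-4, atol=1e-3))
```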

Here is the code I have so far:

```python
import torch
import torch.nn.functional as F

# Falls back to CPU when no GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"

# I'm multiplying by 0.55 just to get some False
# values when applying the threshold
cat_branch = torch.rand(4, 8, 10, 10, device=device) * 0.55
kernel_branch = torch.rand(4, 256, 10, 10, device=device)
feature_map = torch.rand(4, 256, 128, 256, device=device)

# I find the highest class score per cell and apply
# my threshold to see which cells I keep
values, idx = cat_branch.max(dim=1, keepdim=True)  # (4, 1, 10, 10)
values_thr = values > 0.5
values_thr = values_thr.nonzero()  # rows of [batch, 0, y, x]

# Then I take the indices of those cells in the 10x10 grid
# and collect the weights from kernel_branch for each batch element
weights_conv = [torch.empty(0, device=device) for _ in range(feature_map.shape[0])]

for v in values_thr:
    weights_conv[v[0]] = torch.cat((weights_conv[v[0]], kernel_branch[v[0], :, v[2], v[3]]), 0)

# Then, using the predicted weights, I perform
# the convolution on the feature map
seg_preds = []
for i in range(len(feature_map)):
    f = feature_map[i].unsqueeze(0)
    weights = weights_conv[i].view(len(weights_conv[i]) // feature_map.shape[1], feature_map.shape[1], 1, 1)
    seg_preds.append(F.conv2d(f, weights, stride=1).sigmoid())
```
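The loop over `values_thr` that gathers the weights can be replaced by boolean-mask indexing. This is a minimal sketch using the shapes from the question on CPU; `keep` and `batch_idx` are names introduced here for illustration. It relies on the fact that indexing with a boolean mask returns elements in the same row-major order as `nonzero()`, and it checks the result against an explicit loop.

```python
import torch

torch.manual_seed(0)

B, K, C, G = 4, 8, 256, 10  # batch, classes, kernel depth, grid size (as in the question)
cat_branch = torch.rand(B, K, G, G) * 0.55
kernel_branch = torch.rand(B, C, G, G)

# Highest class score per cell, then the same 0.5 threshold as in the question
values, idx = cat_branch.max(dim=1)   # both (B, G, G)
keep = values > 0.5                   # boolean mask (B, G, G)

# Move channels last so each cell is a C-dim vector, then index with the mask.
# Boolean-mask indexing returns kept cells in row-major (b, y, x) order,
# the same order as nonzero().
kernels = kernel_branch.permute(0, 2, 3, 1)[keep]   # (N, C)
batch_idx, ys, xs = keep.nonzero(as_tuple=True)     # each (N,)

# Per-image groups of shape (n_i, C), like the weights_conv list built in the loop
weights_conv = kernels.split(keep.flatten(1).sum(1).tolist())

# Check against the explicit loop over the selected cells
loop_kernels = torch.stack([kernel_branch[b, :, y, x] for b, y, x in zip(batch_idx, ys, xs)])
print(kernels.shape, torch.equal(kernels, loop_kernels))
```

Each `weights_conv[i]` here already has shape `(n_i, 256)`, so `weights_conv[i].view(-1, 256, 1, 1)` gives the conv2d weight directly.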

So what I want to do is use the highest-scoring cells of the category branch to select the set of filters from the kernel branch and then apply them, without having to use these for-loops.

Does anyone have any ideas on how to do this?
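One loop-free option (my own sketch, not taken from the paper's code) is to apply every cell's 1x1 kernel to its own image with a single `torch.einsum`, and only then use the threshold mask to keep the selected cells. It assumes it is acceptable to evaluate all 100 candidate kernels per image and discard the rest; at the full size from the question that intermediate is 4 x 100 x 128 x 256 floats, about 52 MB. The spatial size is shrunk and the random inputs are scaled (both assumptions for the demo) so the check against the per-image `conv2d` loop runs fast and the sigmoid does not saturate.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

B, K, C, G = 4, 8, 256, 10  # batch, classes, kernel depth, grid size (as in the question)
H, W = 32, 64               # feature map shrunk from 128x256 so the check runs fast on CPU

cat_branch = torch.rand(B, K, G, G) * 0.55
# Scaled random values (demo assumption) so the sigmoid does not saturate
kernel_branch = torch.randn(B, C, G, G) / C ** 0.5
feature_map = torch.randn(B, C, H, W)

# 1) Cells whose best class score passes the threshold
values, labels = cat_branch.max(dim=1)   # (B, G, G)
keep = (values > 0.5).flatten(1)         # (B, G*G), cell index n = y*G + x

# 2) Apply every cell's 1x1 kernel to its own image in one einsum:
#    kernels (B, G*G, C) x features (B, C, H, W) -> (B, G*G, H, W)
kernels = kernel_branch.flatten(2).transpose(1, 2)
all_preds = torch.einsum("bnc,bchw->bnhw", kernels, feature_map).sigmoid()

# 3) Keep only the selected cells; rows come out in (batch, cell) order
seg_preds = all_preds[keep]                         # (N, H, W)
batch_idx = keep.nonzero(as_tuple=True)[0]          # image each mask belongs to
cate_labels = labels.flatten(1)[keep]               # predicted class of each kept cell
per_image = seg_preds.split(keep.sum(1).tolist())   # optional: list of B tensors

# Reference: the per-image conv2d loop from the question, for comparison
ref = torch.cat([
    F.conv2d(feature_map[i:i + 1], kernels[i][keep[i]].view(-1, C, 1, 1)).sigmoid()[0]
    for i in range(B)
])
print(seg_preds.shape, torch.allclose(seg_preds, ref, atol=1e-5))
```

If computing all candidates is too expensive for a larger grid, the per-image loop over `conv2d` is the fallback, since each image can keep a different number of cells.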