How to use an array as indices to take values from another array


I am coding a monocular depth perception network that follows a bin + residual process. I have a fixed number of bins. Each bin represents a range of values between two numbers and is defined by its left and right edges and a midpoint (the midpoints are encoded in logarithmic space).

They are defined as:

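import math
import torch

# Placeholder values for illustration; the real depth range comes from the dataset
min_depth, max_depth = 0.1, 10.0
num_bins = 48  # bins 0-47

alpha = min_depth
beta = max_depth
# Bin edges spaced uniformly in log space between alpha and beta
bins_edges = [math.exp(math.log(alpha) + (math.log(beta / alpha) * i) / num_bins) for i in range(num_bins + 1)]
# Midpoint of each bin, stored in log space
midpoints = [(math.log(bins_edges[i + 1]) + math.log(bins_edges[i])) / 2 for i in range(num_bins)]
bins_edges = torch.tensor(bins_edges)
midpoints_log = torch.tensor(midpoints)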
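The same edges and log-space midpoints can also be built without Python list comprehensions. The sketch below is an illustration, not the code from the question: it assumes placeholder values for min_depth, max_depth and num_bins, and uses torch.linspace on the log scale followed by torch.exp, which gives edges equally spaced in log space just like the formula above.

```python
import math
import torch

# Placeholder values (assumed for illustration only)
min_depth, max_depth, num_bins = 0.1, 10.0, 48

# num_bins + 1 points equally spaced between log(alpha) and log(beta)
log_edges = torch.linspace(math.log(min_depth), math.log(max_depth), num_bins + 1, dtype=torch.float64)
bins_edges = torch.exp(log_edges)

# Log-space midpoint of each bin: the mean of its two log-edges
midpoints_log = (log_edges[:-1] + log_edges[1:]) / 2
```

float64 is used here only so the values can be compared against the list version; call .float() on both tensors if the network runs in float32.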

My network predicts two feature maps. The first, which I call bin_logits, tells me the probability that a pixel's depth falls within each bin (a classification problem). The second, called residuals, adds an extra value to the bin's midpoint so that the predicted depth is continuous, using the following formula:

This means that, depending on the bin predicted by bin_logits for each pixel, I need to take the value from the corresponding K-th channel of the residuals feature map. I also need the corresponding midpoint and left and right edges in order to calculate the depth.

I have been able to use the indices returned by argmax on the bin_logits feature map to take the midpoints and the left and right edges, like this:

# dummy output for a batch of 8 images of size 100x200
bin_logits = torch.rand((8, num_bins, 100, 200))
residuals = torch.rand((8, num_bins, 100, 200))

# predicted bin index for every pixel, shape (8, 100, 200)
bin_logits = bin_logits.argmax(1)
midpoints = midpoints_log[bin_logits.flatten()].view_as(bin_logits)
left_edges = bins_edges[bin_logits.flatten()].view_as(bin_logits)
right_edges = bins_edges[(bin_logits + 1).flatten()].view_as(bin_logits)
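As a side note, the flatten()/view_as() round trip is not strictly needed: indexing a 1-D tensor with an integer tensor already returns a result shaped like the index tensor. A minimal sketch, assuming placeholder bin values and the hypothetical name bin_idx for the argmax output:

```python
import torch

num_bins = 48
# Placeholder bin tables (assumed values, only the shapes matter here)
midpoints_log = torch.linspace(-2.3, 2.3, num_bins)
bins_edges = torch.linspace(0.1, 10.0, num_bins + 1)

# Predicted bin per pixel, shape (8, 100, 200), dtype int64
bin_idx = torch.rand(8, num_bins, 100, 200).argmax(1)

# Advanced indexing keeps the shape of the index tensor
midpoints = midpoints_log[bin_idx]
left_edges = bins_edges[bin_idx]
right_edges = bins_edges[bin_idx + 1]
```

This works for the 1-D bin tables, but not directly for residuals, because there the bin index selects along dimension 1 of a 4-D tensor.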

But I do not know how to do the same for the residuals feature map. I tried bin_logits.max(dim=1, keepdim=True), but it only gives me bin indices (0-47 in this case), and I don't know how to use them to select along the channel dimension of residuals.

I need the final tensor to have shape (8,100,200) so that I can apply equation (4) above.

How can I use the argmax of the bin_logits feature map to select the corresponding values in the residuals tensor, the same way I do for the edges and midpoints?

Thanks for any help!

I found the answer using torch.gather:

residuals = torch.gather(residuals, 1, bin_logits.unsqueeze(1)).squeeze(1)

torch.gather requires the index tensor to have the same number of dimensions as the input, so I added a singleton channel dimension with unsqueeze(1), gathered along dim 1, and then removed that extra dimension with squeeze(1). Et voilà!
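To make the gather semantics concrete, here is a small sketch (dummy tensors, and the names idx, selected and selected_adv are hypothetical). For dim=1, gather computes out[b, 0, h, w] = residuals[b, idx[b, 0, h, w], h, w]. It also shows that max(dim=1, keepdim=True) already returns indices of shape (8, 1, 100, 200), so no unsqueeze is needed in that case, and cross-checks the result against explicit advanced indexing over batch, height and width.

```python
import torch

num_bins = 48
bin_logits = torch.rand(8, num_bins, 100, 200)  # dummy network output
residuals = torch.rand(8, num_bins, 100, 200)

# keepdim=True keeps the channel axis: idx has shape (8, 1, 100, 200)
_, idx = bin_logits.max(dim=1, keepdim=True)

# out[b, 0, h, w] = residuals[b, idx[b, 0, h, w], h, w], then drop the channel axis
selected = torch.gather(residuals, 1, idx).squeeze(1)  # shape (8, 100, 200)

# Equivalent advanced indexing: broadcast batch/height/width ranges against idx
B, _, H, W = residuals.shape
b = torch.arange(B).view(B, 1, 1)
h = torch.arange(H).view(1, H, 1)
w = torch.arange(W).view(1, 1, W)
selected_adv = residuals[b, idx.squeeze(1), h, w]  # shape (8, 100, 200)
```

On PyTorch 1.9 or later, torch.take_along_dim(residuals, idx, dim=1) returns the same values as the gather call.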