Hi,

I am coding a monocular depth estimation network that follows a bin + residual approach. I have a fixed number of bins, each representing a range of values between two numbers. Each bin is defined by its left and right edges and by a midpoint (the midpoints are encoded in logarithmic space).

They are defined as:

```
import torch
import math
min_depth=1.
max_depth=126.
num_bins=48
alpha = min_depth
beta = max_depth
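# num_bins + 1 edges, evenly spaced in log space between alpha and beta
bins_edges = [math.exp( math.log(alpha) + ( (math.log(beta/alpha)*i) / num_bins ) ) for i in range(num_bins+1)]
# midpoint of each bin, computed in log space
midpoints = [ ( math.log(bins_edges[i+1]) + math.log(bins_edges[i]) ) / 2 for i in range(num_bins)]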
bins_edges = torch.tensor(bins_edges)
midpoints_log = torch.tensor(midpoints)
```
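For reference, the same edges and log-space midpoints can be built without Python loops. This is a minimal sketch assuming the same `min_depth`, `max_depth` and `num_bins` as above; it uses `torch.linspace` in log space, computes in `float64`, and casts to `float32` at the end (the cast is an assumption on my part, to match the dtype of typical network outputs).

```python
import math
import torch

min_depth = 1.0
max_depth = 126.0
num_bins = 48

# Bin edges are evenly spaced in log space between log(min_depth) and log(max_depth).
log_edges = torch.linspace(math.log(min_depth), math.log(max_depth), num_bins + 1, dtype=torch.float64)

# Edges in linear (depth) space, shape (num_bins + 1,)
bins_edges = torch.exp(log_edges).float()

# Midpoint of each bin in log space: mean of its two log-edges, shape (num_bins,)
midpoints_log = ((log_edges[:-1] + log_edges[1:]) / 2).float()
```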

My network predicts two feature maps: one that I call `bin_logits`, which tells me how likely it is that a pixel's depth falls within each bin (a classification problem), and another called `residuals`, which adds an extra value to the bin's midpoint so that the predicted depth is continuous, using the following formula:

This means that, depending on the bin predicted by `bin_logits`, I need to take the corresponding value from the `K`-th channel of the `residuals` feature map for each pixel. I also need the corresponding midpoint and left and right edges to calculate the depth.
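To pin down exactly which value is wanted, here is a deliberately slow reference version written with explicit loops. It is only a sketch for checking a vectorized solution on small tensors; the function name `select_residuals_loop` is hypothetical, and it assumes `residuals` has shape `(B, K, H, W)` and the per-pixel bin index has shape `(B, H, W)`.

```python
import torch

def select_residuals_loop(residuals: torch.Tensor, bin_idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference helper: for every pixel, read the residual channel of its predicted bin.

    residuals: (B, K, H, W) float tensor, one channel per bin
    bin_idx:   (B, H, W) integer tensor with values in [0, K)
    returns:   (B, H, W) tensor
    """
    B, K, H, W = residuals.shape
    out = residuals.new_empty((B, H, W))
    for b in range(B):
        for h in range(H):
            for w in range(W):
                # Pick channel k = bin_idx[b, h, w] at this pixel
                out[b, h, w] = residuals[b, bin_idx[b, h, w], h, w]
    return out
```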

I have been able to use the indices returned by `argmax` on `bin_logits` to look up the midpoints and the left and right edges like this:

```
# dummy output for a batch of 8 images of size 100x200
bin_logits = torch.rand((8,num_bins,100,200))
residuals = torch.rand((8,num_bins,100,200))
# index of the most likely bin for each pixel, shape (8,100,200)
bin_logits = bin_logits.argmax(1)
midpoints = midpoints_log[bin_logits.flatten()].view_as(bin_logits)
left_edges = bins_edges[bin_logits.flatten()].view_as(bin_logits)
right_edges = bins_edges[(bin_logits+1).flatten()].view_as(bin_logits)
```
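As a side note, the `flatten()`/`view_as()` round trip is not strictly needed: indexing a 1-D tensor with an integer tensor already returns a tensor with the index tensor's shape. A minimal sketch with the same dummy sizes and the bin definitions from above; the argmax result is renamed to `bin_idx` (my naming) so the raw logits are not overwritten.

```python
import math
import torch

min_depth, max_depth, num_bins = 1.0, 126.0, 48
bins_edges = torch.tensor([math.exp(math.log(min_depth) + math.log(max_depth / min_depth) * i / num_bins)
                           for i in range(num_bins + 1)])
midpoints_log = (bins_edges[:-1].log() + bins_edges[1:].log()) / 2

bin_logits = torch.rand((8, num_bins, 100, 200))
bin_idx = bin_logits.argmax(dim=1)        # (8, 100, 200), int64 bin index per pixel

# Indexing a 1-D tensor with an integer tensor returns a tensor shaped like the index tensor.
midpoints = midpoints_log[bin_idx]        # (8, 100, 200)
left_edges = bins_edges[bin_idx]          # (8, 100, 200)
right_edges = bins_edges[bin_idx + 1]     # (8, 100, 200)
```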

But I do not know how to do the same for the `residuals` feature map. I tried using `bin_logits.max(dim=1, keepdim=True)`, but it gives me indices in the range of the number of bins (0-47 in this case).

I need the final tensor to be of shape `(8,100,200)` so that I can apply equation `(4)` seen above.
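Indices in the range 0-47 are in fact what `torch.gather` expects: they are channel indices along `dim=1`. A minimal sketch, assuming the raw (not yet argmax-ed) logits and the same dummy shapes as above; `keepdim=True` keeps the index tensor 4-D, as `gather` requires, and `squeeze(1)` gives the `(8,100,200)` result.

```python
import torch

num_bins = 48
bin_logits = torch.rand((8, num_bins, 100, 200))
residuals = torch.rand((8, num_bins, 100, 200))

# keepdim=True keeps the channel axis: indices has shape (8, 1, 100, 200)
# and holds the bin index k in [0, num_bins) for each pixel.
_, indices = bin_logits.max(dim=1, keepdim=True)

# gather along dim=1: selected[b, 0, h, w] = residuals[b, indices[b, 0, h, w], h, w]
# (the index tensor must have the same number of dimensions as residuals).
selected = torch.gather(residuals, 1, indices)   # (8, 1, 100, 200)
selected = selected.squeeze(1)                   # (8, 100, 200)
```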

**How can I use the `argmax` from the `bin_logits` feature maps to select the corresponding values in the `residuals` tensor?** Like I do for the edges and midpoint.
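An alternative that stays closer to the edge/midpoint lookup is plain advanced indexing, with broadcast `arange` grids for the batch, row and column axes. This is a sketch under the same dummy-shape assumptions; `B, K, H, W` and `bin_idx` are my own names.

```python
import torch

B, K, H, W = 8, 48, 100, 200
bin_logits = torch.rand((B, K, H, W))
residuals = torch.rand((B, K, H, W))
bin_idx = bin_logits.argmax(dim=1)          # (B, H, W)

# Index grids for the batch, row and column axes, shaped to broadcast against bin_idx.
b = torch.arange(B).view(B, 1, 1)
h = torch.arange(H).view(1, H, 1)
w = torch.arange(W).view(1, 1, W)

# All four index tensors broadcast to (B, H, W), so
# selected[b, h, w] == residuals[b, bin_idx[b, h, w], h, w].
selected = residuals[b, bin_idx, h, w]      # (B, H, W)
```

In PyTorch 1.9 and later, `torch.take_along_dim(residuals, bin_idx.unsqueeze(1), dim=1).squeeze(1)` should also select the same values, since it behaves like `gather` when the index tensor has the same number of dimensions as the input.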

Thanks for any help!