Hi,
I am coding a monocular depth perception network that follows a bin + residual approach. I have a fixed number of bins, each representing a range of depth values between two numbers. Each bin is defined by its left and right edges and a midpoint (the midpoints are encoded in logarithmic space).
They are defined as:
```python
import math
import torch

min_depth = 1.0
max_depth = 126.0
num_bins = 48
alpha = min_depth
beta = max_depth
# Bin edges spaced uniformly in log-depth between alpha and beta (num_bins + 1 edges)
bins_edges = [math.exp(math.log(alpha) + (math.log(beta / alpha) * i) / num_bins) for i in range(num_bins + 1)]
# Bin midpoints in log space: average of the logs of each bin's two edges
midpoints = [(math.log(bins_edges[i + 1]) + math.log(bins_edges[i])) / 2 for i in range(num_bins)]
bins_edges = torch.tensor(bins_edges)
midpoints_log = torch.tensor(midpoints)
```
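The list comprehensions above can also be written with vectorized tensor ops. A minimal sketch, assuming float64 is acceptable (the original `torch.tensor(list)` produces float32, so add `.float()` if needed). Note that a midpoint taken in log space is the log of the geometric mean of the two edges:

```python
import math
import torch

min_depth, max_depth, num_bins = 1.0, 126.0, 48

# Equally spaced points in log(depth), from log(min_depth) to log(max_depth)
log_edges = torch.linspace(math.log(min_depth), math.log(max_depth), num_bins + 1, dtype=torch.float64)
bins_edges = torch.exp(log_edges)                      # shape (num_bins + 1,)
# Log-space midpoint of each bin: mean of its two log edges
midpoints_log = (log_edges[:-1] + log_edges[1:]) / 2   # shape (num_bins,)
```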
My network predicts two feature maps. The first, which I call `bin_logits`, gives the probability that a pixel's depth falls within each bin (a classification problem). The second, called `residuals`, provides an extra offset that is combined with the bin's midpoint so that the predicted depth is continuous, using the following formula:
This means that, depending on the bin predicted by `bin_logits` for each pixel, I need to take the corresponding value from the K-th channel of the `residuals` feature map, where K is the predicted bin index. I also need the corresponding midpoint and left and right edges to compute the depth.
Using the indices returned by `argmax` on `bin_logits`, I have been able to select the midpoints and the left and right edges like this:
```python
# Dummy output for a batch of 8 images of size 100x200
bin_logits = torch.rand((8, num_bins, 100, 200))
residuals = torch.rand((8, num_bins, 100, 200))
bin_idx = bin_logits.argmax(1)  # predicted bin per pixel, shape (8, 100, 200)
midpoints = midpoints_log[bin_idx.flatten()].view_as(bin_idx)
left_edges = bins_edges[bin_idx.flatten()].view_as(bin_idx)
right_edges = bins_edges[(bin_idx + 1).flatten()].view_as(bin_idx)
```
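As a side note, indexing a 1-D tensor with an integer tensor already returns a tensor shaped like the index, so the `flatten()`/`view_as()` round trip is optional. A minimal sketch with placeholder edges and midpoints (random values, not the real bins); `bin_idx` is a hypothetical name for the argmax result:

```python
import torch

num_bins = 48
bins_edges = torch.rand(num_bins + 1).sort().values   # placeholder edges, ascending
midpoints_log = torch.rand(num_bins)                   # placeholder log midpoints
bin_logits = torch.rand(8, num_bins, 100, 200)
bin_idx = bin_logits.argmax(dim=1)                     # (8, 100, 200), values 0..num_bins-1

# Direct integer-tensor indexing: output shape == index shape
midpoints = midpoints_log[bin_idx]        # (8, 100, 200)
left_edges = bins_edges[bin_idx]          # (8, 100, 200)
right_edges = bins_edges[bin_idx + 1]     # (8, 100, 200)
```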
But I do not know how to do the same for the `residuals` feature map. I tried `bin_logits.max(dim=1, keepdim=True)`, but it gives me indices in the range of the number of bins (0-47 in this case). I need the final tensor to have shape `(8, 100, 200)` so that I can apply equation (4) above.
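The per-pixel selection described above (at each pixel, take channel K of `residuals`, where K is the predicted bin) can be expressed with PyTorch advanced indexing. A minimal sketch on dummy tensors; `B`, `H`, `W`, `bin_idx` and `selected` are hypothetical names, not from the original code:

```python
import torch

B, num_bins, H, W = 8, 48, 100, 200
bin_logits = torch.rand(B, num_bins, H, W)
residuals = torch.rand(B, num_bins, H, W)
bin_idx = bin_logits.argmax(dim=1)            # (B, H, W), values 0..num_bins-1

# Broadcastable index grids for the batch, row and column dimensions
b = torch.arange(B).view(B, 1, 1)
h = torch.arange(H).view(1, H, 1)
w = torch.arange(W).view(1, 1, W)
# selected[b, h, w] = residuals[b, bin_idx[b, h, w], h, w]
selected = residuals[b, bin_idx, h, w]        # (B, H, W)
```

The three `arange` grids broadcast against `bin_idx` to `(B, H, W)`, so the result already has the shape needed alongside `midpoints`, `left_edges` and `right_edges`.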
How can I use the `argmax` of the `bin_logits` feature map to select the corresponding values in the `residuals` tensor, the same way I do for the edges and midpoints?
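Another option is `torch.gather`, which picks values along one dimension using an index tensor of the same rank. A minimal sketch, assuming the same dummy shapes as in the code above; `bin_idx` and `selected_residuals` are hypothetical names:

```python
import torch

B, num_bins, H, W = 8, 48, 100, 200
bin_logits = torch.rand(B, num_bins, H, W)
residuals = torch.rand(B, num_bins, H, W)

# keepdim=True keeps the channel axis, giving indices of shape (B, 1, H, W),
# which is the index layout torch.gather expects when gathering along dim=1.
bin_idx = bin_logits.argmax(dim=1, keepdim=True)               # int64, values 0..num_bins-1
selected_residuals = torch.gather(residuals, 1, bin_idx).squeeze(1)  # (B, H, W)
```

The 0-47 indices are exactly what is needed: along `dim=1` they address the bin (channel) axis, so the `.indices` returned by `bin_logits.max(dim=1, keepdim=True)` could be passed to `gather` in the same way.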
Thanks for any help!