Hi,
I would like to know if it is possible to access multiple indices across multiple dimensions in a single line using advanced indexing/broadcasting techniques.
My use case is the following: I have an input image tensor of shape (N, C, H_in, W_in) and valid index tensors of shape (N, H_out, W_out) for the height and width axes of the input image, and I want to build an output image tensor of shape (N, C, H_out, W_out) made of the input image sampled at those indices.
The code below works, but the first method runs a Python loop along the batch dimension, which isn't time efficient, and the second method creates a "too big" intermediate tensor and then shrinks it, which isn't memory efficient.
import torch
# Input tensor (N, C, H_in, W_in)
n, c, h_in, w_in = 8, 3, 32, 32
input = torch.randn(n, c, h_in, w_in)
# Index tensors (N, H_out, W_out), containing indices in [0, W_in - 1] and [0, H_in - 1] respectively.
h_out, w_out = 64, 64
idx_x = torch.rand(n, h_out, w_out) * (w_in - 1)
idx_x = idx_x.round().long()
idx_y = torch.rand(n, h_out, w_out) * (h_in - 1)
idx_y = idx_y.round().long()
print(input.shape)
print(idx_x.shape)
print(idx_y.shape)
# Method 1.
# Creating the output tensor (N, C, H_out, W_out) by iterating over the batch dimension first.
output = []
for input_, idx_x_, idx_y_ in zip(input, idx_x, idx_y):
    output += [input_[:, idx_y_, idx_x_]]
output = torch.stack(output, dim=0)
print(output.shape)
# Method 2.
# We first produce an intermediate tensor of shape (N, C, N, H_out, W_out),
# then retain the values along the "diagonal" of the two batch dimensions to obtain an (N, C, H_out, W_out) tensor.
output_ = input[:, :, idx_y, idx_x]
output_ = output_.diagonal(dim1=0, dim2=2).permute(dims=[3, 0, 1, 2])
print(output_.shape)
# Check that the two tensors are equal.
print(output_.allclose(output))
Is there a third method, e.g. a one-line trick with advanced indexing or broadcasting, that avoids both the for-loop and the oversized intermediate tensor?
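For reference, here is a sketch of the direction I have been exploring, reusing the tensors defined above (batch_idx, chan_idx and output3 are names I made up). The idea is to give every dimension its own index tensor and let broadcasting align them, so a single advanced-indexing expression produces the right shape directly. I am not sure this is correct or idiomatic, though:
# Sketch of a possible Method 3: explicit index tensors for every dimension.
batch_idx = torch.arange(n).view(n, 1, 1, 1)  # (N, 1, 1, 1)
chan_idx = torch.arange(c).view(1, c, 1, 1)   # (1, C, 1, 1)
# Unsqueezing a channel dimension makes idx_y/idx_x (N, 1, H_out, W_out);
# all four index tensors then broadcast to (N, C, H_out, W_out).
output3 = input[batch_idx, chan_idx, idx_y.unsqueeze(1), idx_x.unsqueeze(1)]
print(output3.shape)
print(output3.allclose(output))
If that is valid, it avoids both the Python loop and the (N, C, N, H_out, W_out) intermediate, but I would appreciate confirmation that the broadcasting semantics are what I think they are.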
Thanks in advance,
Guillaume