How to select multiple indexes over multiple dimensions at the same time?

Hi,

I would like to know if it is possible to access multiple indexes across multiple dimensions in a single line using advanced indexing/broadcasting techniques.

My use case is the following: I have an input image tensor of shape (N, C, H_in, W_in) and tensors of valid indices of shape (N, H_out, W_out) for the height and width axes of the input image, and I want to build an output image tensor of shape (N, C, H_out, W_out) made of the input image sampled at those indices.

The code below works, but the first method runs a Python loop along the batch dimension, which isn't time efficient, and the second method creates a tensor that is "too big" and then shrinks it, which isn't memory efficient.

import torch

# Input tensor (N, C, H_in, W_in)
n, c, h_in, w_in = 8, 3, 32, 32

input = torch.randn(n, c, h_in, w_in)

# Index tensors (N, H_out, W_out), containing indices in [0, W_in - 1] and [0, H_in - 1] respectively.
h_out, w_out = 64, 64

idx_x = torch.rand(n, h_out, w_out) * (w_in - 1)
idx_x = idx_x.round().long()

idx_y = torch.rand(n, h_out, w_out) * (h_in - 1)
idx_y = idx_y.round().long()


print(input.shape)
print(idx_x.shape)
print(idx_y.shape)

# Method 1.
# Creating the output tensor (N, C, H_out, W_out) by iterating over the batch dimension first.
output = []
for input_, idx_x_, idx_y_ in zip(input, idx_x, idx_y):
    output.append(input_[:, idx_y_, idx_x_])
output = torch.stack(output, dim=0)
print(output.shape)

# Method 2.
# We first produce an (N, C, N, H_out, W_out) tensor, N times larger than needed.
# Then we keep the values on the "diagonal" of the two batch dimensions to obtain an (N, C, H_out, W_out) tensor.
output_ = input[:, :, idx_y, idx_x]
output_ = output_.diagonal(dim1=0, dim2=2).permute(dims=[3, 0, 1, 2])
print(output_.shape)

# Check that the two tensors are equal.
print(output_.allclose(output))

Is there a third method, e.g. a one-line trick with advanced broadcasting, or anything that isn't a for-loop and builds the correct tensor directly?

Thanks in advance,

Guillaume

Hi,

Sure, the code sample is very helpful!

Here is a method 3 that will do it in one op:

# Method 3
# Ensure the view ops below will be valid
assert input.is_contiguous()
assert idx_x.is_contiguous()
assert idx_y.is_contiguous()
# Linearize the dimensions you want to index (H_in and W_in into one axis)
lin_input = input.view(n, c, -1)
lin_idx_x = idx_x.view(n, -1)
lin_idx_y = idx_y.view(n, -1)

# Compute indices in lin_input
lin_indices = lin_idx_y * h_in + lin_idx_x

# Add a channel dimension because we want all channels at every location; expand() broadcasts without copying
lin_indices = lin_indices.unsqueeze(1).expand(n, c, -1)

# Get the values
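# gather(-1, idx) computes lin_out[i, j, k] = lin_input[i, j, idx[i, j, k]],
# i.e. an independent lookup per batch element and per channel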
lin_out = lin_input.gather(-1, lin_indices)

# Un-linearize the result back to (N, C, H_out, W_out)
out = lin_out.view(n, c, h_out, w_out)

print(out.allclose(output))
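As an aside, here is a sketch (not part of the gather solution above) of the "advanced indexing" one-liner the question mentions: broadcasting explicit batch and channel index tensors against the coordinate tensors gives the same result without gather. It reuses n, c, input, idx_y, idx_x and output from the script above.

# Method 3b (sketch): advanced indexing with broadcast batch/channel indices.
# The four index tensors broadcast together to (n, c, h_out, w_out).
batch_idx = torch.arange(n).view(n, 1, 1, 1)
chan_idx = torch.arange(c).view(1, c, 1, 1)
out_ai = input[batch_idx, chan_idx, idx_y.unsqueeze(1), idx_x.unsqueeze(1)]
print(out_ai.allclose(output))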

Awesome, thanks a lot!

Quick question, if we don’t know whether input, idx_x, and idx_y are contiguous, can we use .reshape() as an alternative?

Cheers,

Guillaume

No, you should do input = input.contiguous() first.

The view operation and the linear indices trick used here rely on the fact that the Tensor is contiguous in memory. And you can reshape Tensors that are not contiguous, so .reshape() would silently hide a violation of that assumption instead of surfacing it.
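As a quick illustration (a standalone sketch, separate from the script above): a transposed Tensor shares its storage with the original but has permuted strides, so it is not contiguous; .view() refuses it, while .contiguous() first copies the data into row-major order:

import torch

t = torch.randn(4, 5)
nt = t.t()  # transpose: same storage, non-contiguous strides

print(nt.is_contiguous())  # False
try:
    nt.view(-1)  # .view() cannot reorder memory, so this raises
except RuntimeError as err:
    print(err)

flat = nt.contiguous().view(-1)  # copy into row-major order, then view
print(flat.is_contiguous())  # True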

Nothing serious, but there's a small mistake in there. I tried with different shapes:

n, c, h_in, w_in = 8, 3, 32, 64
h_out, w_out = 256, 128

And the results no longer match.
However, the results agree again if we fix the line:

# lin_indices = lin_idx_y * h_in + lin_idx_x
lin_indices = lin_idx_y * w_in + lin_idx_x

(In row-major layout each row holds w_in elements, so the multiplier for the y index must be w_in; the bug is invisible when h_in == w_in, as in the original shapes.)

Here’s the full script for clarity:

import torch

# Input tensor (N, C, H_in, W_in)
n, c, h_in, w_in = 8, 3, 32, 64

input = torch.randn(n, c, h_in, w_in)

# Index tensors (N, H_out, W_out), containing indices in [0, W_in - 1] and [0, H_in - 1] respectively.
h_out, w_out = 256, 128

idx_x = torch.rand(n, h_out, w_out) * (w_in - 1)
idx_x = idx_x.round().long()

idx_y = torch.rand(n, h_out, w_out) * (h_in - 1)
idx_y = idx_y.round().long()


print(input.shape)
print(idx_x.shape)
print(idx_y.shape)

# Method 1.
# Creating the output tensor (N, C, H_out, W_out) by iterating over the batch dimension first.
output = []
for input_, idx_x_, idx_y_ in zip(input, idx_x, idx_y):
    output.append(input_[:, idx_y_, idx_x_])
output = torch.stack(output, dim=0)
print(output.shape)

# Method 2.
# We first produce an (N, C, N, H_out, W_out) tensor, N times larger than needed.
# Then we keep the values on the "diagonal" of the two batch dimensions to obtain an (N, C, H_out, W_out) tensor.
output_ = input[:, :, idx_y, idx_x]
output_ = output_.diagonal(dim1=0, dim2=2).permute(dims=[3, 0, 1, 2])
print(output_.shape)

# Check that the two tensors are equal.
print(output_.allclose(output))

# Method 3
# Ensure the view ops below will be valid
assert input.is_contiguous()
assert idx_x.is_contiguous()
assert idx_y.is_contiguous()
# Linearize the dimensions you want to index (H_in and W_in into one axis)
lin_input = input.view(n, c, -1)
lin_idx_x = idx_x.view(n, -1)
lin_idx_y = idx_y.view(n, -1)

# Compute linear indices into lin_input: in row-major layout, (y, x) maps to y * w_in + x
# lin_indices = lin_idx_y * h_in + lin_idx_x  # original, buggy line
lin_indices = lin_idx_y * w_in + lin_idx_x

# Add a channel dimension because we want all channels at every location; expand() broadcasts without copying
lin_indices = lin_indices.unsqueeze(1).expand(n, c, -1)

# Get the values
lin_out = lin_input.gather(-1, lin_indices)

# Un-linearize the result back to (N, C, H_out, W_out)
out = lin_out.view(n, c, h_out, w_out)

print(out.allclose(output))

Oh, good catch!
I missed that one.

Happy that you found the fix easily!