Why does `torch.gather` not check the index shape?

The documentation of `torch.gather` says the `index` argument must be an n-dimensional tensor with a specific shape. I assumed this meant the shape would be checked, but it turned out it isn't: the call silently does something unexpected.

The following code has a bug, but it runs without even a warning.

```python
import torch

torch.manual_seed(0)
input = torch.rand(4, 2)
# Bug: this should be unsqueeze(1), giving shape (4, 1); unsqueeze(0) gives (1, 4)
index = torch.randint(2, size=(4,)).unsqueeze(0)
dim = 1
output = torch.gather(input, dim, index)  # runs silently instead of raising a shape error
print("input = ", input)
print("index = ", index)
print("output = ", output)
```
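To see why this runs without complaint, here is a small pure-Python model of `gather` along `dim=1` for 2-D inputs. It is a sketch for illustration only, not PyTorch's implementation; the helper name `gather_dim1` and the sample values are made up. Each output element is `input[i][index[i][j]]` and the output takes the shape of `index`, so a `(1, 4)` index only ever reads from the first row of the input.

```python
def gather_dim1(inp, index):
    """Hypothetical reference model of torch.gather(inp, 1, index) on 2-D lists.

    out[i][j] = inp[i][index[i][j]]; the output has the shape of index.
    No shape validation is done, mirroring the silent behaviour described above.
    """
    return [[inp[i][k] for k in row] for i, row in enumerate(index)]


inp = [[10, 11], [20, 21], [30, 31], [40, 41]]  # shape (4, 2)
idx = [1, 0, 1, 1]                               # one column choice per row

buggy = gather_dim1(inp, [idx])                  # like unsqueeze(0): shape (1, 4)
intended = gather_dim1(inp, [[k] for k in idx])  # like unsqueeze(1): shape (4, 1)

print(buggy)     # [[11, 10, 11, 11]]  (all values come from row 0)
print(intended)  # [[11], [20], [31], [41]]  (one value per row)
```

Both calls are "valid" indexing, which is why nothing fails: the mistake only shows up as a wrongly shaped result built from the wrong rows.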

I think it would be good to check the index shape. Otherwise, the documentation should mention that the shape of `index` is not checked.
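A minimal sketch of the kind of check suggested here, assuming the documented requirement is that `index` has the same number of dimensions as `input` and the same size in every dimension except `dim`. `check_gather_index_shape` is a hypothetical helper, not a PyTorch API. It works on plain shape tuples, so with tensors you would call it as `check_gather_index_shape(tuple(input.shape), tuple(index.shape), dim)` before calling `torch.gather`.

```python
def check_gather_index_shape(input_shape, index_shape, dim):
    """Hypothetical helper: raise ValueError if index_shape breaks the assumed rule.

    Assumed rule: same number of dimensions, and equal sizes in every
    dimension except `dim`.
    """
    if len(index_shape) != len(input_shape):
        raise ValueError(
            f"index has {len(index_shape)} dims, input has {len(input_shape)}")
    ndim = len(input_shape)
    if not -ndim <= dim < ndim:
        raise ValueError(f"dim {dim} out of range for {ndim}-D input")
    dim %= ndim  # normalise negative dims, e.g. -1 -> ndim - 1
    for d, (n_in, n_idx) in enumerate(zip(input_shape, index_shape)):
        if d != dim and n_in != n_idx:
            raise ValueError(
                f"index.size({d}) = {n_idx} but input.size({d}) = {n_in}")


check_gather_index_shape((4, 2), (4, 1), dim=1)      # intended call: passes
try:
    check_gather_index_shape((4, 2), (1, 4), dim=1)  # the bug from the post
except ValueError as e:
    print(e)  # index.size(0) = 1 but input.size(0) = 4
```

This enforces the strict equal-size rule described in the post; if your PyTorch version documents a different constraint for `index`, adjust the comparison accordingly.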

It seems the check was dropped somewhere between PyTorch 1.5.1 and 1.6.
Would you mind creating an issue on GitHub so that we can track it?

Here is the GitHub issue: https://github.com/pytorch/pytorch/issues/47610
