The documentation of torch.gather says the index argument must be an n-dimensional tensor with a certain shape. I assumed this meant the function would validate the shape of index, but it turned out that it doesn't check it and silently does something unexpected.
The following code has a bug, but it runs without even a warning.
```python
import torch
torch.manual_seed(0)
input = torch.rand(4, 2)
index = torch.randint(2, size=(4,)).unsqueeze(0) # intended to be unsqueeze(1)
dim = 1
output = torch.gather(input, dim, index)
print("input = ", input)
print("index = ", index)
print("output = ", output)
```
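To see what the buggy call actually computes, here is a small pure-Python model of gather for 2-D inputs, written from the documented rule `out[i][j] = input[i][index[i][j]]` for `dim == 1` (and `input[index[i][j]][j]` for `dim == 0`). `gather_2d` is a hypothetical illustration, not PyTorch's implementation. The output always takes the shape of `index`, so an index of shape (1, 4) reads only row 0 of the input and yields a 1×4 result instead of the intended 4×1 one.

```python
def gather_2d(src, dim, index):
    """Hypothetical pure-Python model of torch.gather for 2-D nested lists."""
    rows, cols = len(index), len(index[0])
    out = [[None] * cols for _ in range(rows)]  # output has the shape of index
    for i in range(rows):
        for j in range(cols):
            if dim == 0:
                out[i][j] = src[index[i][j]][j]
            elif dim == 1:
                out[i][j] = src[i][index[i][j]]
            else:
                raise ValueError("dim must be 0 or 1 for a 2-D input")
    return out

src = [[10, 11], [20, 21], [30, 31], [40, 41]]  # shape (4, 2)

# unsqueeze(0)-style index, shape (1, 4): only row 0 of src is ever read
print(gather_2d(src, 1, [[0, 1, 1, 0]]))        # [[10, 11, 11, 10]]

# unsqueeze(1)-style index, shape (4, 1): one element picked from each row
print(gather_2d(src, 1, [[0], [1], [1], [0]]))  # [[10], [21], [31], [40]]
```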
I think it would be good to check the shape of index. Otherwise, the documentation should mention that the shape of index is not checked.
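Until such a check exists, one option is to validate shapes yourself before calling gather. The sketch below assumes the strict reading of the docs (index has the same number of dimensions as input and the same size in every dimension except `dim`); `check_gather_shapes` is a hypothetical helper, not part of PyTorch. It works on plain shape tuples, so in the snippet above it could be called as `check_gather_shapes(input.shape, dim, index.shape)`, since `torch.Size` is a subclass of `tuple`.

```python
def check_gather_shapes(input_shape, dim, index_shape):
    """Hypothetical helper: raise ValueError unless index_shape matches
    input_shape in every dimension except `dim` (strict reading of the docs)."""
    input_shape, index_shape = tuple(input_shape), tuple(index_shape)
    n = len(input_shape)
    if len(index_shape) != n:
        raise ValueError(f"index has {len(index_shape)} dims, input has {n}")
    if not -n <= dim < n:
        raise ValueError(f"dim {dim} out of range for {n}-D input")
    dim %= n  # normalize negative dims, e.g. -1 -> n - 1
    for d, (a, b) in enumerate(zip(input_shape, index_shape)):
        if d != dim and a != b:
            raise ValueError(
                f"index.size({d}) = {b} but input.size({d}) = {a}; "
                f"only dim {dim} may differ"
            )

# Shapes from the snippet above: input (4, 2), dim = 1
check_gather_shapes((4, 2), 1, (4, 1))      # unsqueeze(1) shape: passes
try:
    check_gather_shapes((4, 2), 1, (1, 4))  # unsqueeze(0) shape: rejected
except ValueError as e:
    print(e)  # index.size(0) = 1 but input.size(0) = 4; only dim 1 may differ
```

Raising an error, rather than warning, matches the intent here: a mismatched index almost always means a bug like the `unsqueeze(0)` vs `unsqueeze(1)` mix-up.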