The documentation of torch.gather says the index argument must be an n-dimensional tensor with a specific shape. I took this to mean that the dimensions of index would be checked, but it turns out they are not: gather silently does something unexpected instead.
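For reference, this is what torch.gather computes with dim = 1 on 2-D tensors: out[i][j] = input[i][index[i][j]], and the output takes the shape of index. The loop below is a minimal sketch of those semantics, restricted to the 2-D, dim = 1 case; the helper name gather_dim1_reference is hypothetical and not part of PyTorch. Because the loop runs over the shape of index, an index of shape (1, 4) simply produces a (1, 4) result.

```python
import torch

def gather_dim1_reference(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # Hypothetical reference loop for torch.gather(x, 1, index) on 2-D tensors:
    # out[i][j] = x[i][index[i][j]]; the output has the same shape as index.
    out = torch.empty(index.shape, dtype=x.dtype)
    for i in range(index.size(0)):
        for j in range(index.size(1)):
            out[i, j] = x[i, int(index[i, j])]
    return out

torch.manual_seed(0)
x = torch.rand(4, 2)
idx = torch.randint(2, size=(4,)).unsqueeze(0)  # shape (1, 4)
print(gather_dim1_reference(x, idx))
print(torch.gather(x, 1, idx))
```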
The following code has a bug, but it runs without even a warning.
    import torch

    torch.manual_seed(0)
    input = torch.rand(4, 2)
    index = torch.randint(2, size=(4,)).unsqueeze(0)  # intended to be unsqueeze(1)
    dim = 1
    output = torch.gather(input, dim, index)
    print("input = ", input)
    print("index = ", index)
    print("output = ", output)
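To make the silent behavior concrete, here is a small sketch comparing the buggy index shape with the intended one. The variable names choice, wrong and right are illustrative only. With unsqueeze(0) the index has shape (1, 4), so gather reads four entries from row 0 of input and never touches the other rows; with unsqueeze(1) the index has shape (4, 1) and one element is picked from each row.

```python
import torch

torch.manual_seed(0)
input = torch.rand(4, 2)
choice = torch.randint(2, size=(4,))  # one column index (0 or 1) per row

wrong = torch.gather(input, 1, choice.unsqueeze(0))  # index shape (1, 4)
right = torch.gather(input, 1, choice.unsqueeze(1))  # index shape (4, 1)

print(wrong.shape)  # torch.Size([1, 4]): only row 0 of input is ever read
print(right.shape)  # torch.Size([4, 1]): one element picked from each row
```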
I think it would be good to check the index shape. Otherwise, the documentation should mention that the dimensions of index are not checked.
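Until such a check exists, one option is a user-side wrapper. The sketch below is a hypothetical helper (checked_gather is not a PyTorch function) that, assuming the intended contract is "index matches input in every dimension except dim", raises before calling torch.gather. Note that this is deliberately stricter than torch.gather itself, which, as the example above shows, accepts an index that is smaller than input along the other dimensions.

```python
import torch

def checked_gather(input: torch.Tensor, dim: int, index: torch.Tensor) -> torch.Tensor:
    # Hypothetical wrapper, not part of PyTorch: require index to have the same
    # number of dimensions as input and the same size in every dimension except
    # `dim`, then delegate to torch.gather.
    if index.dim() != input.dim():
        raise ValueError(f"index has {index.dim()} dims, input has {input.dim()}")
    dim = dim % input.dim()  # normalize negative dims such as -1
    for d in range(input.dim()):
        if d != dim and index.size(d) != input.size(d):
            raise ValueError(
                f"index.size({d}) = {index.size(d)} but input.size({d}) = {input.size(d)}"
            )
    return torch.gather(input, dim, index)
```

With this wrapper, the buggy call from the post (an index built with unsqueeze(0)) raises a ValueError, while the intended unsqueeze(1) version returns a (4, 1) tensor.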