`torch.gather` is a bit confusing to me. Is it equivalent to ordinary indexing when we only gather along a single axis? For example, take this:
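```python
import torch

values = torch.rand(2, 16, 4)
index = torch.tensor([3, 6, 7])
# gather needs an index with as many dims as values: (3,) -> (1, 3, 1) -> (2, 3, 4)
expanded = index.view(1, -1, 1).expand(values.shape[0], -1, values.shape[2])
out_gather = torch.gather(values, 1, expanded)
```

and this:

```python
out_index = values[:, index, :]  # same original values and 1-D index
```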
Do the two produce exactly the same result?
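One way to check this empirically is to compute both versions from the same input and compare them with `torch.equal`. The sketch below does that, and also tries a case where each batch row uses its own indices (the `per_row` values are made up for illustration). In that case plain indexing needs an explicit batch index built with `torch.arange`, while `gather` takes the per-row index directly once it is expanded to the input's number of dimensions.

```python
import torch

torch.manual_seed(0)
values = torch.rand(2, 16, 4)
index = torch.tensor([3, 6, 7])

# Same indices for every batch row
expanded = index.view(1, -1, 1).expand(values.shape[0], -1, values.shape[2])
out_gather = torch.gather(values, 1, expanded)
out_index = values[:, index, :]
print(out_gather.shape, out_index.shape)   # both torch.Size([2, 3, 4])
print(torch.equal(out_gather, out_index))  # True

# Different indices per batch row (hypothetical example values)
per_row = torch.tensor([[3, 6, 7],
                        [0, 1, 15]])                                    # (2, 3)
per_row_expanded = per_row.unsqueeze(-1).expand(-1, -1, values.shape[2])  # (2, 3, 4)
out_gather_rows = torch.gather(values, 1, per_row_expanded)

# Plain indexing: pair each row number with that row's indices
batch = torch.arange(values.shape[0]).unsqueeze(-1)  # (2, 1), broadcasts to (2, 3)
out_index_rows = values[batch, per_row]              # (2, 3, 4)
print(torch.equal(out_gather_rows, out_index_rows))  # True
```

In both cases the element at `[i, j, k]` comes from `values[i, idx[i, j], k]`, where `idx` is the shared or per-row index; this sketch only compares forward results.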