torch.gather along one axis

I find torch.gather a bit confusing. Is it equivalent to normal indexing when we only select along a single axis?

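import torch

values = torch.rand(2, 16, 4)
index = torch.tensor([3, 6, 7])

# Version 1: gather, with the index expanded to match values on every dim except 1
gather_index = index.view(1, -1, 1).expand(values.shape[0], -1, values.shape[2])
gathered = torch.gather(values, 1, gather_index)

# Version 2: normal indexing with the original 1-D index
indexed = values[:, index, :]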

Do the two versions give exactly the same result?


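Yes, in this case they give the same result, provided the indexing version uses the original 1-D index tensor([3, 6, 7]) and not the expanded 3-D one. They are not the same function in general: torch.gather can pick a different position for every output element, while indexing with a 1-D tensor applies the same positions to every slice. A simpler version is values.index_select(1, index).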
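
To check the claim above, here is a minimal sketch that runs all three variants on the same input and compares them. The variable names (gather_index, via_gather, via_indexing, via_index_select) are only for illustration and do not come from the question.

```python
import torch

values = torch.rand(2, 16, 4)
index = torch.tensor([3, 6, 7])  # 1-D positions to pick along dim 1

# gather needs an index with the same number of dims as the input,
# so the 1-D index is broadcast over dims 0 and 2.
gather_index = index.view(1, -1, 1).expand(values.shape[0], -1, values.shape[2])
via_gather = torch.gather(values, 1, gather_index)

# Normal indexing and index_select both take the 1-D index directly.
via_indexing = values[:, index, :]
via_index_select = values.index_select(1, index)

print(via_gather.shape)
print(torch.equal(via_gather, via_indexing))
print(torch.equal(via_gather, via_index_select))
```

All three results have shape (2, 3, 4) and compare equal. The expand step is only there because gather reads a separate index for each output element; when every slice uses the same positions, the 1-D forms are shorter and easier to read.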