For a 2D torch.gather, specifying dim=0 means that for each entry along dim 0, the indices are used to gather elements along dim 1. Say dim 0 is the batch dimension: the indices would then be used to gather along the feature dimension.
For a 3D torch.gather, specifying dim=0 means that the indices are applied directly to dim 0.
The 3D behavior generalizes to tensors of any number of dimensions, so why does 2D gather seem to be a special case?
Goal: get the probability for each entry of idx by gathering it from source_tensor.
In the 2D case, the approach should be torch.gather(source_tensor, 0, idx), because to sample along the prob_buckets axis we need to sample along dimension 0. My understanding is that dim=0 means each entry [prob_buckets] along the batch dimension gets sampled by the idx tensor.
My understanding could be incorrect, but if it is right, then the 3D example with dim=0 would sample along the batch dimension for each prob_bucket.
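The layout of source_tensor is not shown in the question, so here is a minimal sketch under an assumed layout. The shapes and the names batch, prob_buckets, probs_a, probs_b and probs_ref are hypothetical. The point it illustrates: the dim you pass is simply whichever axis holds the buckets, and the result can be cross-checked against plain advanced indexing.

```python
import torch

torch.manual_seed(0)
# Hypothetical sizes: 4 samples, 5 probability buckets per sample.
batch, prob_buckets = 4, 5

# Assumed layout A: rows are samples, columns are buckets -> (batch, prob_buckets)
source_tensor = torch.softmax(torch.randn(batch, prob_buckets), dim=1)
idx = torch.randint(0, prob_buckets, (batch,))  # one bucket index per sample

# Buckets live on dim=1 here. index must have as many dims as the input,
# so add a size-1 axis for the gather and remove it afterwards.
probs_a = torch.gather(source_tensor, 1, idx.unsqueeze(1)).squeeze(1)

# Assumed layout B: buckets first -> (prob_buckets, batch); now buckets are dim=0.
source_t = source_tensor.t()
probs_b = torch.gather(source_t, 0, idx.unsqueeze(0)).squeeze(0)

# Reference: pick source_tensor[n, idx[n]] for every sample n.
probs_ref = source_tensor[torch.arange(batch), idx]

print(torch.equal(probs_a, probs_ref), torch.equal(probs_b, probs_ref))
# → True True
```

So dim=0 is correct only when the buckets are on the first axis; with a (batch, prob_buckets) layout the same lookup needs dim=1.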
I’m still unsure what exactly is inconsistent.
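As described in the docs, for a 3D tensor the gather operation is applied as:

```
out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
```

The same logic holds for the 2D case if you drop the k dimension:

```
out[i][j] = input[index[i][j]][j]  # if dim == 0
out[i][j] = input[i][index[i][j]]  # if dim == 1
```

In other words, dim only selects which coordinate of the output position is replaced by the value from index; all other coordinates are copied through unchanged. With dim=0 in 2D, every column j is indexed independently along dim 0, exactly as in 3D, so there is no 2D special case. If dim 0 is the batch dimension and dim 1 holds the features, gathering along the features for each sample is dim=1, not dim=0.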
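One way to see that 2D is not a special case is to add a trailing size-1 axis (the k dimension from the docs formulas), run the 3D gather, and drop that axis again. A minimal sketch with arbitrary illustrative shapes; the variable names are hypothetical:

```python
import torch

torch.manual_seed(0)
i, j = 2, 3
x = torch.randn(i, j)
x3 = x.unsqueeze(-1)  # shape (i, j, 1): x3[a][b][0] == x[a][b]

# dim=0: out[a][b] = x[idx0[a][b]][b]
idx0 = torch.randint(0, i, (i, j))
out_2d_dim0 = torch.gather(x, 0, idx0)
# 3D rule with k fixed at 0: out[a][b][0] = x3[idx0[a][b]][b][0]
out_3d_dim0 = torch.gather(x3, 0, idx0.unsqueeze(-1)).squeeze(-1)

# dim=1: out[a][b] = x[a][idx1[a][b]]
idx1 = torch.randint(0, j, (i, j))
out_2d_dim1 = torch.gather(x, 1, idx1)
out_3d_dim1 = torch.gather(x3, 1, idx1.unsqueeze(-1)).squeeze(-1)

print(torch.equal(out_2d_dim0, out_3d_dim0), torch.equal(out_2d_dim1, out_3d_dim1))
# → True True
```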
Here is a small example that compares torch.gather with an explicit loop:
```python
import torch

# 3D
i, j, k = 2, 3, 4
x = torch.randn(i, j, k)
idx = torch.randint(0, i, (i, j, k))  # values must be valid indices into dim 0
# gather
out = torch.gather(x, dim=0, index=idx)
# manual: out[i][j][k] = x[idx[i][j][k]][j][k]
out_manual = torch.zeros(i, j, k)
for i_ in range(i):
    for j_ in range(j):
        for k_ in range(k):
            out_manual[i_, j_, k_] = x[idx[i_, j_, k_], j_, k_]
# compare
print((out == out_manual).all())
# > tensor(True)

# 2D (reuses i and j from above)
x = torch.randn(i, j)
idx = torch.randint(0, i, (i, j))
# gather
out = torch.gather(x, dim=0, index=idx)
# manual: out[i][j] = x[idx[i][j]][j]
out_manual = torch.zeros(i, j)
for i_ in range(i):
    for j_ in range(j):
        out_manual[i_, j_] = x[idx[i_, j_], j_]
# compare
print((out == out_manual).all())
# > tensor(True)
```
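The 3D and 2D loops above differ only in how many coordinates they iterate over. Below is a hypothetical loop-based reference (gather_reference is an illustrative name, not a PyTorch function) that applies the documented rule to any number of dimensions: for each position in index, overwrite the coordinate at dim with the index value and read input there. It is a sketch for understanding, not for speed.

```python
import itertools

import torch


def gather_reference(x: torch.Tensor, dim: int, index: torch.Tensor) -> torch.Tensor:
    """Slow, loop-based gather for tensors of any rank (illustrative only)."""
    out = torch.empty(index.shape, dtype=x.dtype)
    # Visit every output position, e.g. (a, b, c) for a 3D index.
    for pos in itertools.product(*(range(s) for s in index.shape)):
        src = list(pos)
        src[dim] = int(index[pos])  # only the `dim` coordinate comes from index
        out[pos] = x[tuple(src)]
    return out


# Compare against torch.gather for 1D through 4D inputs and every dim.
torch.manual_seed(0)
for shape in [(4,), (2, 3), (2, 3, 4), (2, 3, 2, 2)]:
    x = torch.randn(shape)
    for dim in range(x.dim()):
        index = torch.randint(0, shape[dim], shape)
        assert torch.equal(torch.gather(x, dim, index), gather_reference(x, dim, index))
print("torch.gather matches the reference for every rank and dim tested")
```

Because the same function handles every rank, the 2D behavior is just this rule with two coordinates.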