# Inconsistent gather API

For 2D `torch.gather`, specifying dim 0 means that, for each entry along dim 0, the indices are used to gather elements along dim 1. Say dim 0 is the batch dimension; the indices in this case would be used to gather along the feature dimension.

For 3D `torch.gather`, specifying dim 0 means that the indices are applied directly to dim 0.

The 3D implementation actually generalizes to tensors of any dimension. Why is there a special case for 2D gather?

Could you post examples for the 2D and 3D cases that show this inconsistent behavior, please?

2D example:
`source_tensor` shape: `[batch, prob_buckets]`
`idx` shape: `[batch, num_samples]`

Goal: get the probability rating for each index in `idx` by gathering it from `source_tensor`.

In the 2D case, the approach should be `torch.gather(source_tensor, 0, idx)`. This is because, in order to sample along the `prob_buckets` axis, we need to sample along the 0 dimension. My understanding is that `dim=0` means that each `[prob_buckets]` entry along the `batch` dimension gets sampled by the `idx` tensor.

My understanding could be incorrect, but if it is right, then the 3D example with `dim=0` would sample along the batch dimension for each `prob_buckets` entry.

I’m still unsure what exactly is inconsistent.
As described in the docs, the gather operation is applied as:

```
out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
```

for a 3D case (and the same logic would still work for the 2D case by removing the `k` dimension).
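Dropping the `k` index gives the corresponding 2D rules. For instance, the `dim=1` rule `out[a][b] = input[a][index[a][b]]` can be checked directly; a quick sketch with arbitrarily chosen shapes:

```python
import torch

a_dim, b_dim = 3, 5
x = torch.randn(a_dim, b_dim)
idx = torch.randint(0, b_dim, (a_dim, b_dim))

# gather along dim=1
out = torch.gather(x, dim=1, index=idx)

# check out[a][b] == x[a][idx[a][b]] element by element
for a in range(a_dim):
    for b in range(b_dim):
        assert out[a, b] == x[a, idx[a, b]]
print("dim=1 rule verified")
```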

Here is a small example:

```
import torch

# 3D
i, j, k = 2, 3, 4
x = torch.randn(i, j, k)
idx = torch.randint(0, i, (i, j, k))

# gather
out = torch.gather(x, dim=0, index=idx)

# manual
out_manual = torch.zeros(i, j, k)
for i_ in range(i):
    for j_ in range(j):
        for k_ in range(k):
            out_manual[i_, j_, k_] = x[idx[i_, j_, k_], j_, k_]

# compare
print((out == out_manual).all())
# > tensor(True)

# 2D
x = torch.randn(i, j)
idx = torch.randint(0, i, (i, j))

# gather
out = torch.gather(x, dim=0, index=idx)

# manual
out_manual = torch.zeros(i, j)
for i_ in range(i):
    for j_ in range(j):
        out_manual[i_, j_] = x[idx[i_, j_], j_]

# compare
print((out == out_manual).all())
# > tensor(True)
```
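Applying these rules to the original 2D goal (`source_tensor` of shape `[batch, prob_buckets]`, `idx` of shape `[batch, num_samples]`): since the sample indices pick positions along the `prob_buckets` axis, the per-dim formulas suggest gathering with `dim=1` rather than `dim=0`. A sketch with made-up sizes:

```python
import torch

batch, prob_buckets, num_samples = 4, 10, 3
source_tensor = torch.randn(batch, prob_buckets)
idx = torch.randint(0, prob_buckets, (batch, num_samples))

# out[b][s] = source_tensor[b][idx[b][s]]  -> gather along dim=1
out = torch.gather(source_tensor, dim=1, index=idx)
print(out.shape)  # torch.Size([4, 3])
```

Note that `index` must have the same number of dimensions as `input`, and its size along every dim other than the gather dim must not exceed the input's; here `idx.size(0) == batch` matches, so the output has the shape of `idx`.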