# Inconsistent gather API

For 2D `torch.gather`, specifying dim 0 means that, for each entry along dim 0, the indices are used to gather elements along dim 1. Say dim 0 is the batch dimension; the indices in this case would be used to gather along the feature dimension.

For 3D `torch.gather`, specifying dim 0 means that the indices are applied directly to dim 0.

The 3D implementation actually generalizes to tensors of any dimension. Why is there a special case for 2D gather?

Could you post examples for the 2D and 3D cases that show this inconsistent behavior, please?

2D example:
`source_tensor` shape: `[batch, prob_buckets]`
`idx` shape: `[batch, num_samples]`

Goal: get the probability rating for each index in `idx` by gathering it from `source_tensor`.

In the 2D case, the approach should be `torch.gather(source_tensor, 0, idx)`. This is because, in order to sample along the `prob_buckets` axis, we need to sample along the 0 dimension. My understanding is that `dim=0` means that each `[prob_buckets]` entry along the `batch` dimension gets sampled by the `idx` tensor.

My understanding could be incorrect, but if it is right, then the 3D example with `dim=0` would sample along the batch dimension for each `prob_buckets` entry.

I’m still unsure what exactly is inconsistent.
As described in the docs, the gather operation is applied as:

```
out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
```

for a 3D case (and the same logic would still work for the 2D case by removing the `k` dimension).
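Dropping the `k` index gives the corresponding 2D rules. For instance, the `dim=1` rule `out[a][b] = input[a][index[a][b]]` can be checked directly; a quick sketch with arbitrarily chosen shapes:

```python
import torch

a_dim, b_dim = 3, 5
x = torch.randn(a_dim, b_dim)
idx = torch.randint(0, b_dim, (a_dim, b_dim))

# gather along dim=1
out = torch.gather(x, dim=1, index=idx)

# check out[a][b] == x[a][idx[a][b]] element by element
for a in range(a_dim):
    for b in range(b_dim):
        assert out[a, b] == x[a, idx[a, b]]
print("dim=1 rule verified")
```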

Here is a small example:

```
import torch

# 3D
i, j, k = 2, 3, 4
x = torch.randn(i, j, k)
idx = torch.randint(0, i, (i, j, k))

# gather
out = torch.gather(x, dim=0, index=idx)

# manual
out_manual = torch.zeros(i, j, k)
for i_ in range(i):
    for j_ in range(j):
        for k_ in range(k):
            out_manual[i_, j_, k_] = x[idx[i_, j_, k_], j_, k_]

# compare
print((out == out_manual).all())
# > tensor(True)

# 2D
x = torch.randn(i, j)
idx = torch.randint(0, i, (i, j))

# gather
out = torch.gather(x, dim=0, index=idx)

# manual
out_manual = torch.zeros(i, j)
for i_ in range(i):
    for j_ in range(j):
        out_manual[i_, j_] = x[idx[i_, j_], j_]

# compare
print((out == out_manual).all())
# > tensor(True)
```
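Applying these rules to the original 2D goal (`source_tensor` of shape `[batch, prob_buckets]`, `idx` of shape `[batch, num_samples]`): since the sample indices pick positions along the `prob_buckets` axis, the per-dim formulas suggest gathering with `dim=1` rather than `dim=0`. A sketch with made-up sizes:

```python
import torch

batch, prob_buckets, num_samples = 4, 10, 3
source_tensor = torch.randn(batch, prob_buckets)
idx = torch.randint(0, prob_buckets, (batch, num_samples))

# out[b][s] = source_tensor[b][idx[b][s]]  -> gather along dim=1
out = torch.gather(source_tensor, dim=1, index=idx)
print(out.shape)  # torch.Size([4, 3])
```

Note that `index` must have the same number of dimensions as `input`, and its size along every dim other than the gather dim must not exceed the input's; here `idx.size(0) == batch` matches, so the output has the shape of `idx`.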