Indexing a 3D tensor along its last axis

Hello,

Given a 3D tensor x of shape (N, M, K) and a 2D LongTensor idx of shape (N, M), I want to compute the tensor out defined as

out[i, j] = x[i, j, idx[i, j]]
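To pin down the semantics, here is a straightforward but slow reference implementation with explicit Python loops. It is only a sketch: it assumes x has shape (N, M, K) and idx has shape (N, M), and the helper name gather_last_naive is hypothetical.

```python
import torch

def gather_last_naive(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Compute out[i, j] = x[i, j, idx[i, j]] with explicit loops (reference only)."""
    n, m, _ = x.shape
    out = torch.empty(n, m, dtype=x.dtype, device=x.device)
    for i in range(n):
        for j in range(m):
            # pick the element of the last axis selected by idx[i, j]
            out[i, j] = x[i, j, idx[i, j]]
    return out

# small example: x[i, j, k] = 12*i + 4*j + k
x = torch.arange(24).reshape(2, 3, 4)
idx = torch.tensor([[0, 1, 2], [3, 0, 1]])
print(gather_last_naive(x, idx))  # tensor([[ 0,  5, 10], [15, 16, 21]])
```

The double loop costs one Python-level indexing operation per output element, which is why a vectorized approach is preferable for large tensors.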

Any idea how I could do that efficiently?

Thanks

Found the solution: use torch.gather along the last dimension. For anyone interested:

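idx = idx.unsqueeze(2)   # (N, M) -> (N, M, 1): gather needs idx to have as many dims as x
out = x.gather(2, idx)   # (N, M, 1), out[i, j, 0] = x[i, j, idx[i, j, 0]]
out = out.squeeze(2)     # drop the singleton dim -> (N, M)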
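As a sanity check, the sketch below wraps the gather approach in a helper and compares it against an equivalent advanced-indexing formulation on random data. The helper names gather_last and index_last are hypothetical, and the shapes (4, 5, 6) are arbitrary choices for illustration.

```python
import torch

def gather_last(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    # gather requires idx to have the same number of dims as x,
    # so add a trailing singleton dim, gather along dim 2, then remove it
    return x.gather(2, idx.unsqueeze(2)).squeeze(2)

def index_last(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    # advanced indexing: row/col index grids broadcast against idx to (N, M)
    n, m = idx.shape
    rows = torch.arange(n, device=x.device).unsqueeze(1)  # (N, 1)
    cols = torch.arange(m, device=x.device).unsqueeze(0)  # (1, M)
    return x[rows, cols, idx]

torch.manual_seed(0)
x = torch.randn(4, 5, 6)
idx = torch.randint(0, 6, (4, 5))
out = gather_last(x, idx)
print(out.shape)                              # torch.Size([4, 5])
print(torch.equal(out, index_last(x, idx)))  # True
```

Both versions only copy existing values, so exact equality (torch.equal) is the right comparison here rather than a tolerance-based check.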