Assume you have a tensor y of size M x 10 and another tensor x of size M. The values of both x and y are integers between 0 and 9.
I want to get a new tensor z of size M which is of the following form:
z[k] = y[k][x[k]]
Is it possible to do this without using a for loop?
You could use torch.gather for this:
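```python
import torch

M = 10
y = torch.empty(M, 10, dtype=torch.long).random_(10)  # M x 10, values in [0, 9]
x = torch.empty(M, dtype=torch.long).random_(10)      # M indices in [0, 9]
# gather along dim 1: w[k][0] = y[k][x[k][0]], so the index needs shape M x 1
w = torch.gather(y, 1, x.unsqueeze(1))
w = w.view(-1)  # flatten M x 1 -> M
# loop-based reference result, for comparison
z = torch.zeros(M, dtype=torch.long)
for k in range(M):
    z[k] = y[k, x[k]]
print(z == w)  # every element should be True
```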
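To make the gather semantics concrete, here is a small sketch with fixed, made-up values instead of random ones. With dim=1, torch.gather computes out[i][j] = y[i][index[i][j]], which is why x has to be unsqueezed to shape M x 1 first. For comparison it also shows advanced indexing with torch.arange, a different technique (not part of the answer above) that selects the same elements:

```python
import torch

# Small fixed example (values chosen for illustration): M = 3 rows, 4 columns.
y = torch.tensor([[10, 11, 12, 13],
                  [20, 21, 22, 23],
                  [30, 31, 32, 33]])
x = torch.tensor([2, 0, 3])  # one column index per row

# torch.gather with dim=1 computes out[i][j] = y[i][index[i][j]],
# so an M x 1 index picks exactly one element per row; squeeze(1) gives shape M.
z_gather = torch.gather(y, 1, x.unsqueeze(1)).squeeze(1)

# Alternative: advanced indexing pairs row k with column x[k].
z_index = y[torch.arange(y.size(0)), x]

print(z_gather)  # tensor([12, 20, 33])
print(z_index)   # tensor([12, 20, 33])
```

Both produce a 1-D tensor of length M. gather keeps the extra dimension until it is squeezed or viewed away, while the indexing form returns shape M directly.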
Thanks, indeed gather was exactly what I was looking for.