Indexing a 2D tensor along multiple dimensions using another 2D tensor

I have an (N, V) matrix of predictions. I want to index it with an (N, C) matrix of column indices to obtain an (N, C) matrix that holds, for each row, the predictions at the indices given in the corresponding row of the second matrix.

Example:

import torch

a = torch.tensor([[0.1, 0.3, 0.5, 0.1],
                  [0.2, 0.2, 0.3, 0.3]])

b = torch.tensor([[0, 1],
                  [2, 3]])  # int64 (LongTensor) indices

I want to do something like c = a[b] to obtain

tensor([[0.1, 0.3],
        [0.3, 0.3]])
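The requested operation is c[i][j] = a[i][b[i][j]], which is what torch.gather computes along dim=1. It is not mentioned in this thread; it is swapped in here as one standard way to do this. A minimal sketch, assuming a recent PyTorch where torch.tensor replaces the legacy FloatTensor/LongTensor constructors:

```python
import torch

a = torch.tensor([[0.1, 0.3, 0.5, 0.1],
                  [0.2, 0.2, 0.3, 0.3]])
b = torch.tensor([[0, 1],
                  [2, 3]])  # int64 by default, which gather requires

# With dim=1, gather returns out[i][j] = a[i][b[i][j]];
# b must have the same number of dimensions as a.
c = torch.gather(a, 1, b)  # equivalently a.gather(1, b)
print(c)
```

gather is differentiable with respect to a, so gradients flow back only into the selected entries.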

I tried a.take(b), but take treats a as a flattened 1D tensor, so indices 0 to 3 only pick values from the first row of a (i.e., [[0.1, 0.3], [0.5, 0.1]]).
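Because take indexes into a as though it were flattened in row-major order, it can still be used if each index in b is shifted by the start of its row, i.e. by i * V. A sketch of that workaround, reusing the a and b from the question (row_offsets is an illustrative name):

```python
import torch

a = torch.tensor([[0.1, 0.3, 0.5, 0.1],
                  [0.2, 0.2, 0.3, 0.3]])
b = torch.tensor([[0, 1],
                  [2, 3]])

N, V = a.shape
# The flat (row-major) position of a[i, k] is i * V + k.
row_offsets = torch.arange(N).unsqueeze(1) * V  # shape (N, 1), broadcasts over C
c = torch.take(a, b + row_offsets)
print(c)
```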

EDIT:

The following does the job, but I don’t know if it messes with the computation graph.

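# Preallocate the output without requires_grad: an in-place write into
# a leaf tensor that requires grad raises an error.
c = torch.empty(b.shape, dtype=a.dtype)
for i in range(a.size(0)):
    c[i, :] = a[i, b[i]]  # row i of a, at the columns listed in b[i]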
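One way to check whether the loop keeps the computation graph intact is to backpropagate through it and inspect a.grad. A minimal sketch, assuming a is a leaf that requires grad and c is preallocated without requires_grad. Autograd records each slice assignment, so every selected entry of a should receive a gradient of 1 and every other entry 0:

```python
import torch

a = torch.tensor([[0.1, 0.3, 0.5, 0.1],
                  [0.2, 0.2, 0.3, 0.3]], requires_grad=True)
b = torch.tensor([[0, 1],
                  [2, 3]])

c = torch.empty(b.shape, dtype=a.dtype)
for i in range(a.size(0)):
    c[i, :] = a[i, b[i]]  # in-place copy, recorded by autograd

c.sum().backward()
print(c.requires_grad)  # True: c is now part of the graph
print(a.grad)           # 1 at the selected positions, 0 elsewhere
```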

Something like the following works:

import torch

a = torch.tensor([[0.1, 0.3, 0.5, 0.1],
                  [0.2, 0.2, 0.3, 0.3]])

# First list gives the row of each element, second gives its column
a[[[0, 0], [1, 1]], [[0, 1], [2, 3]]]

So it looks like you’d want to pair b with a row-index tensor of the same shape, here [[0, 0], [1, 1]], and index with a[[[0, 0], [1, 1]], b].
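The row-index tensor does not have to be written out by hand: torch.arange(N) reshaped to (N, 1) broadcasts against the (N, C) shape of b, so a[rows, b] picks a[i, b[i, j]] for every i and j. A sketch, assuming the a and b from the question (rows is an illustrative name):

```python
import torch

a = torch.tensor([[0.1, 0.3, 0.5, 0.1],
                  [0.2, 0.2, 0.3, 0.3]])
b = torch.tensor([[0, 1],
                  [2, 3]])

# Row indices of shape (N, 1); broadcasting against b (N, C) has the
# same effect as the explicit [[0, 0], [1, 1]] list.
rows = torch.arange(a.size(0)).unsqueeze(1)
c = a[rows, b]
print(c)
```

This avoids the Python loop and stays differentiable with respect to a.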