Fetch values of a 3D tensor using a 2D index tensor

Hi, I have a 3D tensor M of shape bs x N x N and a 2D index tensor I of shape bs x N. M holds the pairwise distances between two sets of N points, for each of bs batches. I is a bipartite matching between the two sets: I[b][i] = j means that, in batch b, point i of the first set is matched with point j of the second. I need to fetch the distances of all matched pairs, across all batches. A naive (slow) way to do it would be something like this:

```python
import torch

total_distance = 0
for batch in range(bs):
    index = I[batch]  # matching for this batch: shape (N,)
    # picks M[batch, i, index[i]] for i = 0..N-1, then sums them
    distances_for_this_batch = M[batch, torch.arange(N), index].sum()
    total_distance += distances_for_this_batch
```
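To make the layout concrete, here is a minimal sketch with toy inputs. The random point coordinates, the use of `torch.cdist` for Euclidean distances, and a random permutation standing in for the real matching are all assumptions for illustration; only the loop itself comes from the question.

```python
import torch

# Hypothetical toy sizes; any bs, N, d work.
bs, N, d = 2, 4, 3
torch.manual_seed(0)

# Two sets of N points per batch (random here, purely for illustration).
pts_a = torch.rand(bs, N, d)
pts_b = torch.rand(bs, N, d)

# M[k, i, j] = Euclidean distance between pts_a[k, i] and pts_b[k, j] -> shape (bs, N, N)
M = torch.cdist(pts_a, pts_b)

# I[k, i] = j : point i of the first set is matched to point j of the second in batch k.
# A random permutation per batch stands in for a real bipartite matching (assumption).
I = torch.stack([torch.randperm(N) for _ in range(bs)])  # shape (bs, N), dtype int64

# The per-batch loop from the question.
total_distance = 0
for batch in range(bs):
    index = I[batch]
    total_distance += M[batch, torch.arange(N), index].sum()
```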

Is there a better/faster way to do this?
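One loop-free sketch (not from the original post): `torch.gather` along the last dimension picks `M[b, i, I[b, i]]` for every `b` and `i` in a single call, and advanced indexing with broadcast `arange` tensors does the same thing. The function names are hypothetical, and `I` is assumed to be an int64 tensor on the same device as `M`.

```python
import torch

def matched_distances_gather(M: torch.Tensor, I: torch.Tensor) -> torch.Tensor:
    """Return D with D[b, i] = M[b, i, I[b, i]]; M: (bs, N, N), I: (bs, N) int64."""
    # gather needs an index with the same number of dims as M: (bs, N, 1)
    return torch.gather(M, 2, I.unsqueeze(-1)).squeeze(-1)  # (bs, N)

def matched_distances_index(M: torch.Tensor, I: torch.Tensor) -> torch.Tensor:
    """Same result via advanced indexing."""
    bs, N = I.shape
    batch_idx = torch.arange(bs, device=M.device).unsqueeze(1)  # (bs, 1)
    row_idx = torch.arange(N, device=M.device).unsqueeze(0)     # (1, N)
    # The three index tensors broadcast to (bs, N): out[b, i] = M[b, i, I[b, i]]
    return M[batch_idx, row_idx, I]
```

Both return the bs x N matched distances, so `.sum(dim=1)` gives per-batch totals and `.sum()` the overall total; both are differentiable with respect to `M`.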