Selecting values in a multi-dimensional tensor using indices

I have a tensor A of shape (batch, dim1, dim2, K), where K is the number of classes. I have another tensor B that contains the indices of the most probable values of A. B is not derived by applying argmax to A.

How can I gather the values of A at the indices given in B? Call the result C. I expect C to have the same shape as B.
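
To make the desired behaviour concrete, here is a minimal sketch. It rests on assumptions the question does not state: that the data are NumPy arrays, that B indexes the last (class) axis of A, and that B has shape either (batch, dim1, dim2) or (batch, dim1, dim2, k). The helper name `gather_last_axis` is hypothetical and used only for illustration.

```python
import numpy as np

def gather_last_axis(A, B):
    """Pick values of A along its last axis at the positions given in B.

    Assumed shapes (not stated in the question):
      A: (batch, dim1, dim2, K)
      B: (batch, dim1, dim2) or (batch, dim1, dim2, k), integers in [0, K)
    Returns an array with the same shape as B.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    if B.ndim == A.ndim - 1:
        # One index per position: add a trailing axis, gather, then drop it.
        return np.take_along_axis(A, B[..., None], axis=-1)[..., 0]
    # Several indices per position (e.g. top-k): gather directly.
    return np.take_along_axis(A, B, axis=-1)

# Example with random data; B is drawn independently of A (no argmax).
rng = np.random.default_rng(0)
A = rng.random((2, 3, 4, 5))             # K = 5 classes
B = rng.integers(0, 5, size=(2, 3, 4))   # one class index per position
C = gather_last_axis(A, B)
print(C.shape)  # (2, 3, 4)
```

If the tensors live in PyTorch instead, `torch.gather(A, -1, B)` follows the same pattern, provided B has the same number of dimensions as A.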