How can I use a vector as a per-channel threshold in (single_mean < total_mean)?

single_mean[(single_mean<total_mean).detach()] = 0
In this case, total_mean = [4, 6] and single_mean has shape 2x4x4. How can I compare the first channel against 4 and the second channel against 6? Thanks!
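
For reference, here is a minimal sketch of what I think the behaviour should look like, assuming broadcasting is the right tool. The example values in single_mean are made up for illustration; only the shapes and total_mean = [4, 6] come from my actual case. The idea is to reshape total_mean from shape (2,) to (2, 1, 1) so that each threshold lines up with its own channel:

```python
import torch

# Made-up example data: 2 channels, each 4x4, with values in 0..9.
single_mean = torch.arange(32, dtype=torch.float32).reshape(2, 4, 4) % 10
total_mean = torch.tensor([4.0, 6.0])

# Reshape the thresholds to (2, 1, 1) so they broadcast over each
# channel's 4x4 values: channel 0 is compared with 4, channel 1 with 6.
mask = (single_mean < total_mean.view(-1, 1, 1)).detach()

# Zero out every entry that falls below its own channel's threshold.
single_mean[mask] = 0
```

Without the reshape, a (2,) tensor would be aligned with the last dimension of size 4 and the comparison would fail with a shape mismatch, which is probably why the original line does not work as written.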