I have 8 images (say X) represented as an 8x3x100x100 tensor, where 3 is the number of channels and 100x100 is the size of each image. I also have a vector “index” that indicates which channel to take from each image, e.g., “index” = [1,1,2,0,…]. After selecting the appropriate channel from each image, the output should have a size of 8x1x100x100. How can I do this?
I tried X.index_select(1, index), but it gives me 8x8x100x100 as the output.
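The 8x8x100x100 result is expected. `index_select(1, index)` applies the same list of channel indices to every image, so dimension 1 grows to `len(index)` = 8. What is needed instead is a per-image pick: image `i` keeps only channel `index[i]`. Below is a minimal sketch of two ways to do that, assuming `index` is a 1-D `int64` tensor with one entry per image. The helper names and the last four index values are hypothetical, since the original list is elided with "…". The first helper uses advanced indexing, pairing `torch.arange(N)` with `index`. The second uses `torch.gather` with the index broadcast to shape (N, 1, H, W).

```python
import torch


def select_channel_indexing(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: keep channel index[i] of image i, returning (N, 1, H, W)."""
    n = x.size(0)
    # Advanced indexing pairs row i with channel index[i], giving shape (N, H, W).
    out = x[torch.arange(n, device=x.device), index]
    # Re-insert the singleton channel dimension: (N, H, W) -> (N, 1, H, W).
    return out.unsqueeze(1)


def select_channel_gather(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: the same selection, done with torch.gather along dim 1."""
    n, _, h, w = x.shape
    # gather needs an int64 index with the same number of dims as x;
    # broadcast each image's channel id across its H x W pixels.
    idx = index.view(n, 1, 1, 1).expand(n, 1, h, w)
    return x.gather(1, idx)


X = torch.randn(8, 3, 100, 100)
# First four values come from the question; the rest are made up for illustration.
index = torch.tensor([1, 1, 2, 0, 2, 0, 1, 2])

out = select_channel_indexing(X, index)
print(out.shape)  # torch.Size([8, 1, 100, 100])
```

Both helpers return the same tensor. The indexing version is usually the shorter one to write. `gather` can be convenient when the index already has the output's shape, for example if the chosen channel varied per pixel rather than per image.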