I’ve asked the same question on Stack Overflow but haven’t gotten an answer so far; maybe it is better suited for this site:
I have a PyTorch tensor of shape (n, 1, h, w) for arbitrary integers n, h, and w (in my specific case this array represents a batch of n grayscale images of dimension h x w).
I also have another tensor of shape (m, 2) which maps every possible value in the first array (i.e. the first array can contain values from 0 to m - 1) to some pair of values. I would like to “apply” this mapping to the first array so that I obtain an array of shape (n, 2, h, w).
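In other words, calling the result c, I want c[i, :, j, k] == b[a[i, 0, j, k]] for every batch index i and every pixel position (j, k).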
I hope this is somewhat clear; I find this hard to express in words. Here’s a code example (but note that it is not super intuitive either, due to the four-dimensional arrays involved):
import torch

m = 18

# could also be an arbitrary tensor with this shape with values between 0 and m - 1
a = torch.arange(m).reshape(2, 1, 3, 3)

# could also be an arbitrary tensor with this shape
b = torch.stack((torch.arange(m), torch.arange(m)), dim=1)

# I probably have to do this and the permute/reshape, but how?
c = b.index_select(0, a.flatten())

# I don't know how to proceed from here, but I would like to end up with:
# [[[[ 0,  1,  2],
#    [ 3,  4,  5],
#    [ 6,  7,  8]],
#
#   [[ 0,  1,  2],
#    [ 3,  4,  5],
#    [ 6,  7,  8]]],
#
#
#  [[[ 9, 10, 11],
#    [12, 13, 14],
#    [15, 16, 17]],
#
#   [[ 9, 10, 11],
#    [12, 13, 14],
#    [15, 16, 17]]]]
How can I perform this transformation in an efficient manner? One thing that seems to work (but is very slow) is:
c = torch.stack([
    b.index_select(0, a_.flatten()).reshape(3, 3, 2).permute(2, 0, 1)
    for a_ in a
])
This is somewhat frustrating, since it is really just a matter of correctly reshaping the result of b.index_select(0, a_.flatten()), and it seems like this should be possible without for loops or the need to allocate additional memory (which I assume is what makes the call to torch.stack so slow).
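To make it concrete, here is a sketch of the kind of loop-free version I am imagining: a single index_select over the flattened batch, followed by a reshape back to the spatial dimensions and a permute that moves the pair dimension into position 1 (the names n, h, w are just taken from a.shape):

import torch

m = 18
a = torch.arange(m).reshape(2, 1, 3, 3)  # same example batch as above
b = torch.stack((torch.arange(m), torch.arange(m)), dim=1)  # same mapping as above

n, _, h, w = a.shape
c = (
    b.index_select(0, a.flatten())  # shape (n * h * w, 2)
    .reshape(n, h, w, 2)            # restore batch and spatial dimensions
    .permute(0, 3, 1, 2)            # move the pair dimension to position 1 -> (n, 2, h, w)
)

For the small example above this produces the expected output, but I am not sure whether it is correct in general, or whether it is the idiomatic/most efficient way to do it.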