I’ve asked the same question on Stack Overflow but haven’t gotten an answer so far; maybe it is better suited for this site:
I have a PyTorch tensor of shape (n, 1, h, w) for arbitrary integers n, h, and w (in my specific case this array represents a batch of n grayscale images of dimension h x w).
I also have another tensor of shape (m, 2) which maps every possible value in the first array (i.e. the first array can contain values from 0 to m - 1) to some pair of values. I would like to “apply” this mapping to the first array so that I obtain an array of shape (n, 2, h, w).
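In other words, calling the result c, I want c[i, :, j, k] == b[a[i, 0, j, k]] for every batch index i and every pixel position (j, k).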
I hope this is somewhat clear; I find this hard to express in words. Here’s a code example (but note that it is not super intuitive either, due to the four-dimensional arrays involved):
import torch

m = 18

# could also be an arbitrary tensor with this shape with values between 0 and m - 1
a = torch.arange(m).reshape(2, 1, 3, 3)

# could also be an arbitrary tensor with this shape
b = torch.stack((torch.arange(m), torch.arange(m)), dim=1)

# I probably have to do this and the permute/reshape, but how?
c = b.index_select(0, a.flatten())

# I don't know how to proceed from here, but I would like to end up with:
# [[[[ 0,  1,  2],
#    [ 3,  4,  5],
#    [ 6,  7,  8]],
#
#   [[ 0,  1,  2],
#    [ 3,  4,  5],
#    [ 6,  7,  8]]],
#
#
#  [[[ 9, 10, 11],
#    [12, 13, 14],
#    [15, 16, 17]],
#
#   [[ 9, 10, 11],
#    [12, 13, 14],
#    [15, 16, 17]]]]
How can I perform this transformation in an efficient manner? One thing that seems to work (but is very slow) is:
c = torch.stack([
    b.index_select(0, a_.flatten()).reshape(3, 3, 2).permute(2, 0, 1)
    for a_ in a
])
This is somewhat frustrating, since it is really just a matter of correctly reshaping the result of b.index_select(0, a_.flatten()), and it seems like this should be possible without for loops or the need to allocate additional memory (which I assume is what makes the call to torch.stack so slow).
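To make it concrete, here is a sketch of the kind of loop-free version I am imagining: a single index_select over the flattened batch, followed by a reshape back to the spatial dimensions and a permute that moves the pair dimension into position 1 (the names n, h, w are just taken from a.shape):

import torch

m = 18
a = torch.arange(m).reshape(2, 1, 3, 3)  # same example batch as above
b = torch.stack((torch.arange(m), torch.arange(m)), dim=1)  # same mapping as above

n, _, h, w = a.shape
c = (
    b.index_select(0, a.flatten())  # shape (n * h * w, 2)
    .reshape(n, h, w, 2)            # restore batch and spatial dimensions
    .permute(0, 3, 1, 2)            # move the pair dimension to position 1 -> (n, 2, h, w)
)

For the small example above this produces the expected output, but I am not sure whether it is correct in general, or whether it is the idiomatic/most efficient way to do it.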