Apply mapping over singleton dimension of tensor

I’ve asked the same question on Stack Overflow but haven’t gotten an answer so far; maybe it is better suited for this site:

I have a PyTorch tensor of shape (n, 1, h, w) for arbitrary integers n, h, and w (in my specific case, this tensor represents a batch of n grayscale images of dimensions h × w).

I also have another tensor of shape (m, 2) which maps every possible value in the first tensor (i.e. the first tensor can contain values from 0 to m - 1) to some tuple of values. I would like to “apply” this mapping to the first tensor so that I obtain a tensor of shape (n, 2, h, w).

I hope this is somewhat clear; I find it hard to express in words. Here’s a code example (though that is not super intuitive either, due to the four-dimensional tensors involved):

import torch

m = 18

# could also be an arbitrary tensor of this shape with values between 0 and m - 1
a = torch.arange(m).reshape(2, 1, 3, 3)

# could also be an arbitrary tensor of this shape
b = torch.stack((torch.arange(m), torch.arange(m)), dim=1)

# I probably have to do this and then permute/reshape, but how?
c = b.index_select(0, a.flatten())

# I don't know how to proceed from here but I would like to end up with:

#[[[[ 0,  1,  2],
#   [ 3,  4,  5],
#   [ 6,  7,  8]],
#
#  [[ 0,  1,  2],
#   [ 3,  4,  5],
#   [ 6,  7,  8]]],
#
#
# [[[ 9, 10, 11],
#   [12, 13, 14],
#   [15, 16, 17]],
#
#  [[ 9, 10, 11],
#   [12, 13, 14],
#   [15, 16, 17]]]]

How can I perform this transformation in an efficient manner? One thing that seems to work (but is very slow) is:

c = torch.stack([
    b.index_select(0, a_.flatten()).reshape(3, 3, 2).permute(2, 0, 1)
    for a_ in a
])

This is somewhat frustrating, since it should just be a matter of correctly reshaping the result of b.index_select(0, a.flatten()). It seems like this should be possible without for loops or allocating additional memory (which I assume is what makes the call to torch.stack so slow).
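Concretely, I imagine the loop-free version would look something like the sketch below (using the example values from above; viewing the flat lookup result as (n, h, w, 2) before permuting is my assumption about the right intermediate layout):

```python
import torch

m = 18
a = torch.arange(m).reshape(2, 1, 3, 3)                     # (n, 1, h, w)
b = torch.stack((torch.arange(m), torch.arange(m)), dim=1)  # (m, 2)

n, _, h, w = a.shape

# Look up all values at once, restore the batch/spatial layout,
# then move the tuple dimension into the channel position.
c = b.index_select(0, a.flatten()).view(n, h, w, 2).permute(0, 3, 1, 2)
# c.shape == (n, 2, h, w)
```

Note that permute returns a non-contiguous view here, so a .contiguous() call may be needed before further reshaping.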

I think I might have made a mistake while measuring the execution time the first time round; it doesn’t seem to be that slow after all. But it would still be nice to have a more readable way of doing the same thing.
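One candidate for a more readable formulation might be plain advanced indexing instead of index_select (a sketch; I’m assuming that indexing b with an integer tensor performs the per-element lookup I want):

```python
import torch

m = 18
a = torch.arange(m).reshape(2, 1, 3, 3)                     # (n, 1, h, w)
b = torch.stack((torch.arange(m), torch.arange(m)), dim=1)  # (m, 2)

# Indexing b with the integer tensor a.squeeze(1) looks up one tuple
# per pixel, yielding shape (n, h, w, 2); permute then moves the
# tuple dimension into the channel position.
c = b[a.squeeze(1)].permute(0, 3, 1, 2)
# c.shape == (n, 2, h, w)
```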