Gathering in multidimensional tensors: weird broadcasting

I have a 5-dimensional tensor S of shape (bs,a,b,c,d), which can be seen as a batch of (a x b) matrices whose entries are themselves (c x d) matrices. Let's say bs = 1 and a = b = c = d = 3 for simplicity.

I also have a vector of shape (b) (say ind = [2,0,1]) that contains column indices into the inner c x d matrices. For each row i and column j of the outer a x b matrix, I need to gather column ind[j] of the corresponding c x d matrix, collecting the results in an (a x b x c) tensor.
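
To pin down the operation described above, here is a naive loop-based reference. It is only a sketch of the intended semantics, not an efficient solution; the function name gather_columns_reference and the example sizes are made up for illustration.

```python
import torch

def gather_columns_reference(S, ind):
    """Naive reference: out[n, i, j, :] = S[n, i, j, :, ind[j]]."""
    bs, a, b, c, d = S.shape
    out = S.new_empty((bs, a, b, c))
    for i in range(a):
        for j in range(b):
            # column ind[j] of the inner (c x d) matrix at outer position (i, j)
            out[:, i, j, :] = S[:, i, j, :, int(ind[j])]
    return out

# Arbitrary sizes and values, chosen only so that entries are distinguishable
S = torch.arange(2 * 3 * 4 * 5 * 6).reshape(2, 3, 4, 5, 6)
ind = torch.tensor([2, 0, 1, 5])  # one column index per outer column j
out = gather_columns_reference(S, ind)
print(out.shape)  # torch.Size([2, 3, 4, 5])
```

Any vectorised version should agree with this loop on tensors of arbitrary shape, which makes it usable as a correctness check.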

Example S:

tensor([[[[[   0,    0,    0],
           [   0,    0,    0],
           [   0,    0,    0]],

          [[1211, 1212, 1213],
           [1221, 1222, 1223],
           [1231, 1232, 1233]],

          [[1311, 1312, 1313],
           [1321, 1322, 1323],
           [1331, 1332, 1333]]],


         [[[1211, 1221, 1231],
           [1212, 1222, 1232],
           [1213, 1223, 1233]],

          [[   0,    0,    0],
           [   0,    0,    0],
           [   0,    0,    0]],

          [[2311, 2312, 2313],
           [2321, 2322, 2323],
           [2331, 2332, 2333]]],


         [[[1311, 1321, 1331],
           [1312, 1322, 1332],
           [1313, 1323, 1333]],

          [[2311, 2321, 2331],
           [2312, 2322, 2332],
           [2313, 2323, 2333]],

          [[   0,    0,    0],
           [   0,    0,    0],
           [   0,    0,    0]]]]])

I found that

S[:,torch.arange(3)[None,:],torch.arange(3)[:,None],ind[None,:]]

gives the expected result for this specific case:

tensor([[[[   0,    0,    0],
          [1211, 1221, 1231],
          [1312, 1322, 1332]],

         [[1231, 1232, 1233],
          [   0,    0,    0],
          [2312, 2322, 2332]],

         [[1331, 1332, 1333],
          [2311, 2312, 2313],
          [   0,    0,    0]]]])

However, I don't understand why it works or how broadcasting applies here. My indices all have shape (1,3) or (3,1) and should broadcast to shape (3,3), yet the result has shape (3,3,3) (plus the batch dimension), as required.
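
A small check of the shapes involved may help here. This is a sketch assuming PyTorch follows NumPy-style advanced indexing; S below is a stand-in with arbitrary values, and i1, i2, i3 are names introduced only for this example.

```python
import torch

# Stand-in tensor with arbitrary values, shape (bs, a, b, c, d) = (1, 3, 3, 3, 3)
S = torch.arange(3 ** 4).reshape(1, 3, 3, 3, 3)
ind = torch.tensor([2, 0, 1])

i1 = torch.arange(3)[None, :]   # shape (1, 3), indexes dim 1
i2 = torch.arange(3)[:, None]   # shape (3, 1), indexes dim 2
i3 = ind[None, :]               # shape (1, 3), indexes dim 3

# The three index tensors broadcast together to (3, 3)
print(torch.broadcast_shapes(i1.shape, i2.shape, i3.shape))  # torch.Size([3, 3])

out = S[:, i1, i2, i3]
print(out.shape)  # torch.Size([1, 3, 3, 3])

# Elementwise meaning: out[n, p, q, :] == S[n, q, p, ind[q], :]
for p in range(3):
    for q in range(3):
        assert torch.equal(out[0, p, q], S[0, q, p, int(ind[q])])
```

The last axis d is never indexed, so it is kept whole after the broadcast (3,3) block, which appears to be where the third dimension of size 3 comes from. Note also that the expression picks row ind[q] of the inner matrix at outer position (q, p). That coincides with column ind[q] at position (p, q) only because the example S satisfies S[i,j] = S[j,i] transposed, so it would probably not carry over to an arbitrary S.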

What would be a general solution for arbitrary bs,a,b,c,d?

T.