How can I compute my own embedding lookup layer/index_select in 3 dimensions?

I’ve created some reconstructed embeddings of size [batch_size, embedding_dimension, sequence_length], and I want to be able to index_select those embeddings based on a lookup matrix of larger dimensions.

For example, I have an embedding matrix of shape [12, 4, 13], where each batch has its own embeddings which have been calculated:

c = Variable containing:
(0 ,.,.) = 

Columns 0 to 8 
   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000

Columns 9 to 12 
   0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000

(1 ,.,.) = 

Columns 0 to 8 
   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000

Columns 9 to 12 
   0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000

(2 ,.,.) = 

Columns 0 to 8 
   0.4401  0.4401  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.2005  0.2005  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
 -0.1747 -0.1747  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
 -0.6075 -0.6075  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000

Columns 9 to 12 
   0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000

(3 ,.,.) = 

Columns 0 to 8 
   0.3171  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
 -0.4901  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
 -0.1445  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
 -0.4877  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000

Columns 9 to 12 
   0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000

(4 ,.,.) = 

Columns 0 to 8 
  -0.5359  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
 -0.4196  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.1970  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0173  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000

Columns 9 to 12 
   0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000

(5 ,.,.) = 

Columns 0 to 8 
   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000

Columns 9 to 12 
   0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000

(6 ,.,.) = 

Columns 0 to 8 
   0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000

Columns 9 to 12 
   0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000

(7 ,.,.) = 

Columns 0 to 8 
   0.3171 -0.0451  0.3721  0.3721 -0.0451  0.4610  0.2726  0.4610  0.2726
 -0.4901 -0.4598 -0.4402 -0.4402 -0.4598 -0.6407 -0.3215 -0.6407 -0.3215
 -0.1445  0.9977  0.2344  0.2344  0.9977 -0.0442 -0.0341 -0.0442 -0.0341
 -0.4877  0.0229  0.0757  0.0757  0.0229  0.1043  0.2794  0.1043  0.2794

Columns 9 to 12 
   0.4610  0.4610  0.0000  0.0000
 -0.6407 -0.6407  0.0000  0.0000
 -0.0442 -0.0442  0.0000  0.0000
  0.1043  0.1043  0.0000  0.0000

(8 ,.,.) = 

Columns 0 to 8 
   0.3171  0.2738  0.2684  0.2222  0.7377  0.0002  0.0002  0.0000  0.0000
 -0.4901  0.0548 -0.3797 -0.8453 -0.8063 -0.6694 -0.6694  0.0000  0.0000
 -0.1445  0.0201  0.5294  1.0955  0.1553  0.1219  0.1219  0.0000  0.0000
 -0.4877  0.5877  0.0506  0.5659 -0.6621 -0.0036 -0.0036  0.0000  0.0000

Columns 9 to 12 
   0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000

(9 ,.,.) = 

Columns 0 to 8 
   0.3171  0.4198  0.4198 -0.6232  0.6900  0.8220 -0.5885  0.3171  0.0000
 -0.4901  0.1644  0.1644  0.0501 -0.0060 -0.1977  0.3468 -0.4901  0.0000
 -0.1445 -0.2152 -0.2152 -0.0906  0.0158 -0.2144  0.5921 -0.1445  0.0000
 -0.4877 -0.2854 -0.2854  0.2168 -0.2466 -0.1935  0.1861 -0.4877  0.0000

Columns 9 to 12 
   0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000

(10,.,.) = 

Columns 0 to 8 
   0.3395  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.3312  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
 -0.2439  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
  0.2298  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000

Columns 9 to 12 
   0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000

(11,.,.) = 

Columns 0 to 8 
   0.3171  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
 -0.4901  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
 -0.1445  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
 -0.4877  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000

Columns 9 to 12 
   0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  0.0000
[torch.FloatTensor of size 12x4x13]

I now need to re-align these embeddings using a lookup matrix of indices of shape [12, 50, 3], where for each batch I want 3 embeddings row-stacked per row, selected by indices such as:

icd_positions = Variable containing:
(0 ,.,.) = 
  11  11  11
  11  11  11
  11  11  11
     ⋮      
  11  11  11
  11  11  11
  11  11  11

(1 ,.,.) = 
  11  11  11
  11  11  11
  11  11  11
     ⋮      
  11  11  11
  11  11  11
  11  11  11

(2 ,.,.) = 
  11  11  11
  11  11  11
  11  11  11
     ⋮      
  11  11  11
  11  11  11
  11  11  11
...

(9 ,.,.) = 
  11  11  11
  11  11  11
  11  11  11
     ⋮      
  11  11  11
  11  11  11
  11  11  11

(10,.,.) = 
  11  11  11
  11  11  11
  11  11  11
     ⋮      
  11  11  11
  11  11  11
  11  11  11

(11,.,.) = 
  11  11  11
  11  11  11
  11  11  11
     ⋮      
  11  11  11
  11  11  11
  11  11  11
[torch.LongTensor of size 12x50x3]

The goal is an output of shape [12, 50, 3, 4].

Note that each batch has its own set of indices corresponding to its own slice of the embedding matrix: e.g. indices 0–12 in icd_positions[0:1,:,:] index into the embeddings in c[0:1,:,:], the same indices in icd_positions[1:2,:,:] index into c[1:2,:,:], and so on (so the indices are not unique across batches).

How can I do this lookup? I’ve tried something along the lines of

torch.cat([torch.index_select(a, 1, i).unsqueeze(0) for a, i in zip(c, icd_positions)])

which I think should be close, but it’s not working.
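One likely reason that attempt fails (a guess, since no error message is shown above) is that torch.index_select requires a 1-D index tensor, while each `i` coming out of `zip(c, icd_positions)` is a 2-D [50, 3] slice. Flattening the indices first and reshaping afterwards sidesteps that; a minimal sketch with random stand-ins for the tensors above (written against modern PyTorch, so plain tensors instead of Variables):

```python
import torch

# Toy stand-ins with the shapes from the example above
c = torch.randn(12, 4, 13)                      # [batch, embedding_dim, seq_len]
icd_positions = torch.randint(13, (12, 50, 3))  # [batch, 50, 3] per-batch indices

out = torch.cat([
    torch.index_select(a, 1, i.reshape(-1))     # flatten [50, 3] -> [150] indices
         .reshape(4, 50, 3)                     # back to [embedding_dim, 50, 3]
         .unsqueeze(0)
    for a, i in zip(c, icd_positions)
])                                              # -> [12, 4, 50, 3]
out = out.permute(0, 2, 3, 1)                   # -> [12, 50, 3, 4]
```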

Never mind, it’s resolved! You can concatenate across dimensions with nested list comprehensions, e.g.

torch.cat([torch.cat([torch.index_select(d, 1, i).unsqueeze(0) for i in tst]).unsqueeze(0) for d, tst in zip(c, icd_positions)])

(Note that this yields [12, 50, 4, 3], since index_select on dim 1 returns [4, 3] per row; a final transpose of the last two dimensions gives the [12, 50, 3, 4] target.)
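For reference, the same per-batch lookup can be done without Python loops via advanced indexing, which also produces the [12, 50, 3, 4] shape directly. A sketch, again with random stand-ins for the tensors above and modern-PyTorch tensors:

```python
import torch

c = torch.randn(12, 4, 13)                      # [batch, embedding_dim, seq_len]
icd_positions = torch.randint(13, (12, 50, 3))  # per-batch column indices

# Batch indices shaped to broadcast against icd_positions' [12, 50, 3] shape.
b = torch.arange(12).view(12, 1, 1)

# The two index tensors broadcast to [12, 50, 3]; because they are separated
# by a slice, those broadcast dims come first in the result, followed by the
# sliced embedding dimension -> [12, 50, 3, 4].
out = c[b, :, icd_positions]
```

Each `out[b, r, k]` is the 4-dim embedding column `c[b, :, icd_positions[b, r, k]]`, so the per-batch index semantics described above are preserved.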