How to use EmbeddingBag with uneven 3D data?

I have a dataset of shape B x T x C, where B is the batch size, T is the number of timesteps (uneven), and C is the number of characters per timestep (uneven). I would like to use EmbeddingBag to get a mean embedding of the characters in each timestep.
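For reference, `EmbeddingBag` with `mode='mean'` gives the same result as looking each index up in an `nn.Embedding` with the same weights and averaging the rows of a bag. A small sketch that checks this for the bag [0, 4]. The sizes match the embedder used below, and the seed and variable names are arbitrary:

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for reproducibility

# EmbeddingBag in 'mean' mode: one output vector per bag
bag_embedder = torch.nn.EmbeddingBag(num_embeddings=6, embedding_dim=2, mode='mean')

# A plain Embedding sharing the same weight matrix, for comparison
plain_embedder = torch.nn.Embedding(num_embeddings=6, embedding_dim=2)
with torch.no_grad():
    plain_embedder.weight.copy_(bag_embedder.weight)

chars = torch.tensor([[0, 4]])                    # 2D input: each row is one bag
bag_mean = bag_embedder(chars)                    # shape (1, 2)
manual_mean = plain_embedder(chars).mean(dim=1)   # (1, 2, 2) averaged over the bag -> (1, 2)

print(torch.allclose(bag_mean, manual_mean))  # True
```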

For example, let's say I have three datapoints in my batch:

  1. [[], [0, 4], [1, 1], [5]]
     • This has 4 timesteps, with 0, 2, 2, and 1 characters, respectively.
  2. [[1], [2, 3]]
     • This has 2 timesteps, with 1 and 2 characters, respectively.
  3. [[2, 4, 5], []]
     • This has 2 timesteps, with 3 and 0 characters, respectively.

So let’s init that:
all_tensors = [[[], [0, 4], [1, 1], [5]], [[1], [2,3]], [[2, 4, 5], []]]

And I know this is what I want my embedder to look like:
embedder = torch.nn.EmbeddingBag(num_embeddings = 6, embedding_dim = 2, mode = 'mean')

And… this is where I am stuck. Is there a good tutorial on how to approach this problem?
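One way around the ragged shape is to skip padding entirely: treat every timestep as one bag, concatenate all character indices into a single 1D `input`, and pass the start position of each bag as `offsets`. Empty timesteps become empty bags, which `EmbeddingBag` returns as zero vectors in `mean` mode. The bag vectors can then be split back per datapoint and padded along time. A minimal sketch, assuming zero vectors are acceptable for both empty and padded timesteps. The helper `flatten_bags` and the other variable names are hypothetical, not from the post:

```python
import torch
import torch.nn.utils.rnn as rnn_utils

def flatten_bags(batch):
    """Hypothetical helper: turn a B x T x C nested list into EmbeddingBag's 1D input + offsets."""
    flat, offsets, steps_per_item = [], [], []
    for datapoint in batch:
        steps_per_item.append(len(datapoint))
        for timestep in datapoint:
            offsets.append(len(flat))  # this timestep's bag starts here
            flat.extend(timestep)
    return (torch.tensor(flat, dtype=torch.long),
            torch.tensor(offsets, dtype=torch.long),
            steps_per_item)

all_tensors = [[[], [0, 4], [1, 1], [5]], [[1], [2, 3]], [[2, 4, 5], []]]
embedder = torch.nn.EmbeddingBag(num_embeddings=6, embedding_dim=2, mode='mean')

flat, offsets, steps_per_item = flatten_bags(all_tensors)
# flat    -> tensor([0, 4, 1, 1, 5, 1, 2, 3, 2, 4, 5])
# offsets -> tensor([0, 0, 2, 4, 5, 6, 8, 11])  (empty bags share a start position)

bag_means = embedder(flat, offsets)  # one vector per timestep: (8, 2)

# Split the 8 bag vectors back into datapoints (4, 2 and 2 timesteps), then pad along time
per_item = list(torch.split(bag_means, steps_per_item))
embedded = rnn_utils.pad_sequence(per_item, batch_first=True)  # (B, T_max, 2) = (3, 4, 2)
```

Padded timesteps and genuinely empty timesteps both come out as zero vectors here; if a downstream RNN needs to tell them apart, `steps_per_item` can serve as the sequence lengths (for example with `rnn_utils.pack_padded_sequence`).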

Edit: I think I got a bit closer…

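import numpy as np
import torch
import torch.nn.utils.rnn as rnn_utils

def pad_array(base_input):
    # Pad each timestep's characters to length 5 with 0 (note: 0 is also a real character index here)
    for index1, datapoint in enumerate(base_input):
        # dtype=np.int64 keeps empty timesteps integer-typed instead of float
        base_input[index1] = torch.tensor(np.asarray([np.pad(np.asarray(a, dtype=np.int64), (0, 5 - len(a)), 'constant',
                                                             constant_values=0) for a in datapoint]), dtype=torch.long)
    return base_input

all_tensors = [[[], [0, 4], [1, 1], [5]], [[1], [2, 3]], [[2, 4, 5], []]]
paddedchar_tensors = pad_array(all_tensors)
# Pad along the time dimension -> shape (B, T_max, 5)
paddedchar_tensors = rnn_utils.pad_sequence(paddedchar_tensors, batch_first=True)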

This gives me paddedchar_tensors as:

tensor([[[0, 0, 0, 0, 0],
         [0, 4, 0, 0, 0],
         [1, 1, 0, 0, 0],
         [5, 0, 0, 0, 0]],

        [[1, 0, 0, 0, 0],
         [2, 3, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0]],

        [[2, 4, 5, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0]]])

But once again I am stuck: running this through the EmbeddingBag gives me this error: ValueError: input has to be 1D or 2D Tensor, but got Tensor of dimension 3
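The error arises because `EmbeddingBag` only accepts 1D input (with `offsets`) or 2D input (each row is one bag). A padded B x T x C tensor can therefore be reshaped to (B*T, C), embedded, and reshaped back to (B, T, embedding_dim). Padding with 0 is a problem, though: 0 is a real character here, so padded entries would be averaged in. The sketch below reserves an extra index as padding and passes it as `padding_idx`, which (in recent PyTorch versions) excludes that index from the mean. The value `PAD = 6` and the variable names are my own assumptions, not from the post:

```python
import torch
import torch.nn.utils.rnn as rnn_utils

PAD = 6  # hypothetical: one extra index reserved for padding (real characters are 0-5)

all_tensors = [[[], [0, 4], [1, 1], [5]], [[1], [2, 3]], [[2, 4, 5], []]]
max_chars = max(len(step) for item in all_tensors for step in item)  # 3

# Pad characters inside each timestep with PAD -> one (T, max_chars) tensor per datapoint
per_item = [torch.tensor([step + [PAD] * (max_chars - len(step)) for step in item], dtype=torch.long)
            for item in all_tensors]
# Pad along time with PAD as well -> (B, T_max, max_chars)
padded = rnn_utils.pad_sequence(per_item, batch_first=True, padding_value=PAD)

# One extra embedding row for PAD; entries equal to padding_idx are left out of the mean
embedder = torch.nn.EmbeddingBag(num_embeddings=7, embedding_dim=2, mode='mean', padding_idx=PAD)

B, T, C = padded.shape
bag_means = embedder(padded.view(B * T, C))  # 2D input: each of the B*T rows is one bag
embedded = bag_means.view(B, T, -1)          # back to (B, T, 2)
```

Bags made only of padding (empty or padded timesteps) come out as zero vectors.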


@Abhishaike_Mahajan did you ever find a solution to this problem? I’m looking for a solution for a similar application I’m working on and would really appreciate any pointers if you’ve had success. Thank you!