How to use embedding layers for batched multilabel categories?

So, I have time-series data where there can be multiple categories per time point. I would like to use an embedding lookup table to get each individual category's embedding and take the average of them for each time point. I dug around a little and found a neat little EmbeddingBag class seemingly made exactly for this type of setup.

So, below, I have 4 timepoints, each composed of a maximum of 2 categories (index 0 is my padding index, so the last datapoint is entirely padding). I use mode='max' here to sanity-check that the function is interpreting my data correctly.

embedding_max = nn.EmbeddingBag(10, 3, mode='max')
input = torch.LongTensor([[1,2], [1,1], [2,2], [0,0]])
print(embedding_max(input))

tensor([[-0.2409,  0.4761,  0.0832],
        [-1.9867,  0.4761, -0.5098],
        [-0.2409,  0.0551,  0.0832],
        [-0.5566, -0.3860,  0.6931]], grad_fn=<EmbeddingBagBackward>)
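For context, my understanding is that with a 2D input and no offsets, EmbeddingBag treats each row as one bag, so mode='max' should match a plain nn.Embedding lookup followed by an elementwise max over the category dimension. A small sanity check (variable names are my own, and this is just a sketch):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bag = nn.EmbeddingBag(10, 3, mode='max')
# Share the bag's weights with a plain Embedding for comparison.
emb = nn.Embedding.from_pretrained(bag.weight.detach())

x = torch.LongTensor([[1, 2], [1, 1], [2, 2], [0, 0]])
# Each row of x is one "bag"; mode='max' reduces the embeddings
# in each bag with an elementwise max.
assert torch.allclose(bag(x), emb(x).max(dim=1).values)
```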

This works just fine. However, I do not know how to extend this to the batched case, e.g. where I have a batch of two such sequences. In code form…

embedding_max = nn.EmbeddingBag(10, 3, mode='max')
input = torch.LongTensor([[[1,2], [1,1], [2,2], [0,0]], [[2,2], [1,1], [0,0], [0,0]]])
print(embedding_max(input))

ValueError: input has to be 1D or 2D Tensor, but got Tensor of dimension 3

I get a value error. Any ideas on what I’m missing?
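The only workaround I can think of (a sketch, and I'm not sure it's the intended API) is to collapse the batch and time dimensions into a single bag dimension before the lookup, then reshape the result back afterwards:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
emb = nn.EmbeddingBag(10, 3, mode='max')

# 3D batched input: (batch=2, timepoints=4, categories=2)
batched = torch.LongTensor([[[1, 2], [1, 1], [2, 2], [0, 0]],
                            [[2, 2], [1, 1], [0, 0], [0, 0]]])
B, T, C = batched.shape
flat = batched.view(B * T, C)    # collapse batch & time into one bag axis
out = emb(flat).view(B, T, -1)   # restore (batch, time, embedding_dim)
print(out.shape)  # torch.Size([2, 4, 3])
```

Since view just reinterprets the same rows, each (batch, timepoint) pair is still processed as its own bag, so this should agree with running each sequence through EmbeddingBag separately. Is this the right approach, or is there built-in support for batched input?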