So, I have time-series data where there can be multiple categories per time point. I would like to use an embedding lookup table to get each individual category's embedding, and then take their average for each time point. I dug a little and found a neat little EmbeddingBag class, seemingly made exactly for this type of setup.
So, below, I have 4 time points, each composed of at most 2 categories. The 2nd-to-last data point is half-padded, and the last data point is completely padded. I use the ‘max’ mode here just to check that EmbeddingBag is interpreting my data the way I expect.
import torch
import torch.nn as nn

embedding_sum = nn.EmbeddingBag(10, 3, mode='max')
input = torch.LongTensor([[1,2], [1,1], [2,2], [0,0]])
print(embedding_sum(input))
tensor([[-0.2409,  0.4761,  0.0832],
        [-1.9867,  0.4761, -0.5098],
        [-0.2409,  0.0551,  0.0832],
        [-0.5566, -0.3860,  0.6931]], grad_fn=<EmbeddingBagBackward>)
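For the actual model I would presumably just swap mode='max' for mode='mean' once the layout is confirmed, so that each time point gets the average of its category embeddings. A minimal sketch of what I mean (embedding_mean is just my own name for it):

embedding_mean = nn.EmbeddingBag(10, 3, mode='mean')  # average the category embeddings per time point
print(embedding_mean(input))                          # shape (4, 3): one averaged vector per row/bag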
This all works just fine for a single sequence. However, I do not know how to extend this to a batched case, for example where I have two such sequences at once. In code form…
embedding_sum = nn.EmbeddingBag(10, 3, mode='max')
input = torch.LongTensor([[[1,2], [1,1], [2,2], [0,0]], [[2,2], [1,1], [0,0], [0,0]]])
print(embedding_sum(input))
ValueError: input has to be 1D or 2D Tensor, but got Tensor of dimension 3
I get a value error. Any ideas on what I’m missing?
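For what it's worth, the only workaround I have thought of so far is to flatten the batch and time dimensions into one before calling EmbeddingBag, and then reshape the result back. A rough sketch (batched_input, b, t, c are just my own names), though I am not sure whether this is the intended way to batch it:

batched_input = torch.LongTensor([[[1,2], [1,1], [2,2], [0,0]],
                                  [[2,2], [1,1], [0,0], [0,0]]])   # shape (batch=2, time=4, cats=2)
b, t, c = batched_input.shape
flat_out = embedding_sum(batched_input.view(b * t, c))  # collapse batch/time into one bag dimension -> (8, 3)
print(flat_out.view(b, t, -1).shape)                    # restore the batch dimension -> torch.Size([2, 4, 3])

Is something along these lines the recommended approach, or does EmbeddingBag (maybe via its offsets argument) support this more directly?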