I have a torch.nn.Parameter P of shape (10, 50). (The first dimension is really 50000; I replaced it with 10 to keep the explanation simple.)
Conceptually, each index of the first dimension corresponds to an ‘embedding’ of size (1, 50).
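For concreteness, P is set up roughly like this (the random initialization is just for illustration):

import torch

# Toy stand-in for my real parameter: 10 embeddings of size 50 each.
P = torch.nn.Parameter(torch.randn(10, 50))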
My model has two 'paths' to the output. The first path uses each of the 10 embeddings individually, whereas the second path is supposed to use averages over certain neighboring embeddings. My problem is with the second path; I only mention the first path so it is clear why I need to 'average' at all.
I have a dictionary
d = {0: [0], 1: [2, 3, 4], 2: [5, 6, 7, 8, 9]}
in which each key serves as the index i of a group of embeddings that belong together (i.e. that should be averaged), and the corresponding value lists the indices of the embeddings I would like to average.
My forward function gets passed a batch of indices, each of which corresponds to a key of d. Now I want to get the averaged embedding for each index in the batch, so that I can process it further in my forward function.
I tried the following:
averaged_embeddings = []
for idx in batch:
    emb_idxs = d[idx]                  # indices of the embeddings in this group
    emb_count = len(emb_idxs)
    # the indices within a group are contiguous, so narrow() can slice them out
    embeddings = torch.narrow(P, 0, emb_idxs[0], emb_count)
    average = torch.sum(embeddings, dim=0) / emb_count
    average = average.view(1, 50)      # each averaged embedding is a row of size 50
    averaged_embeddings.append(average)
batch_of_averaged_embeddings = torch.cat(averaged_embeddings, 0)
The problem is that this seems to be incredibly inefficient. My batch sizes are in the 10000s, and I process more than 1000 batches per epoch. A single epoch now takes ~2.5 hours, compared to the ~30 seconds it took without these averages.
Is there a better way to accomplish what I want to do? For a given index from the batch, it is always the same embeddings that get averaged. So maybe I could first determine which groups appear in the batch, do the averaging once per group, and fill the output tensor with copies of the results (see the sketch below). But is there more I can do to make this faster, maybe something completely different?
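To make that idea concrete, here is a rough sketch of what I have in mind (group_averages is just a placeholder name I made up; this assumes the keys of d are 0, 1, ..., num_groups - 1 in order):

# Average each group once per forward pass; row i holds the average of group i.
group_averages = torch.stack(
    [P[idxs].mean(dim=0) for idxs in d.values()]
)  # shape (num_groups, 50)

# A single fancy-indexing lookup then replaces the whole Python loop,
# since batch contains group keys.
batch_of_averaged_embeddings = group_averages[batch]  # (batch_size, 50)

This still computes every group's average on each forward pass, even for groups that do not appear in the batch, but that cost depends only on the number of groups, not on the batch size.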
Any help would be appreciated.