What is an efficient way to average over subtensors of a torch.nn.Parameter?

I have a torch.nn.Parameter P of shape (10, 50). (The real first dimension is 50000; I use 10 here to keep the example easier to explain.)
Conceptually, each index of the first dimension corresponds to an ‘embedding’ of size (1, 50).
My model has two ‘paths’ to the output: the first path uses all of the 10 embeddings individually, whereas the second path is supposed to use averages over certain neighboring embeddings. My problem is with the second path; I only mention the first so it is clear why I need to ‘average’ at all.
I have a dictionary

d = {0: [0], 1: [2, 3, 4], 2: [5, 6, 7, 8, 9]}

in which each key serves as the index i of a group of embeddings that belong together (i.e. that should be averaged), and the corresponding value lists the indices of the embeddings I would like to take the average of.

My forward function gets passed a batch of indices, each of which corresponds to a key of d. Now I want to get the averaged embedding for each index in the batch, to further process it in my forward function.

I tried the following:

averaged_embeddings = []
for id in batch:
    emb_idxs = d[id]
    emb_count = len(emb_idxs)
    # each group's embeddings are contiguous rows of P, so narrow selects them
    embeddings = torch.narrow(P, 0, emb_idxs[0], emb_count)
    average = torch.sum(embeddings, dim=0) / emb_count
    average = average.view(1, 50)  # keep a row shape so the results can be concatenated
    averaged_embeddings.append(average)
batch_of_averaged_embeddings = torch.cat(averaged_embeddings, 0)

The problem is that this seems to be incredibly inefficient. My batch sizes are in the 10000s, and in one epoch I process more than 1000 batches. A single epoch now takes ~2.5 hours, compared to the ~30 seconds it took without these averages.

Is there a better way to accomplish what I want to do? For a given index from the batch, it’s always the same embeddings that are averaged. So maybe I could first check which groups appear in the batch, do the averaging once per group, and then fill the output tensor with copies of the results. But is there more I can do to make this faster, maybe something completely different?
Any help would be appreciated.

It seems like you only need a matrix multiplication here: Weights @ P = averages. A Weights tensor of shape (batch, num_embeddings) would replace your d dict and contain “contributions” that are either zero or 1/emb_count, so the product gives the masked averages.
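
For illustration, a minimal sketch of that idea with the toy d from your post (the names weights and weight_batch are just mine); building one row of weights per group once and indexing it with the batch gives exactly such a (batch, num_embeddings) matrix:

import torch

P = torch.nn.Parameter(torch.randn(10, 50))
d = {0: [0], 1: [2, 3, 4], 2: [5, 6, 7, 8, 9]}

# one row per group, one column per embedding:
# an entry is 1/emb_count if the embedding belongs to the group, else 0
weights = torch.zeros(len(d), P.shape[0])
for group_id, emb_idxs in d.items():
    weights[group_id, emb_idxs] = 1.0 / len(emb_idxs)

batch = torch.tensor([2, 0, 2, 1])               # batch of group ids
weight_batch = weights[batch]                    # (batch_size, num_embeddings)
batch_of_averaged_embeddings = weight_batch @ P  # (batch_size, 50)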

Thank you for the suggestion. I have now tried to do the following:

  1. initialize a weight tensor, and
  2. adjust the forward function accordingly.

But now I run into memory problems (see the end of this post).

Initializing weights:

group_ids = d.keys()
weights_shape = (len(group_ids), P.shape[0])
weights = torch.zeros(size=weights_shape, dtype=torch.double, requires_grad=False, device='cuda')

for group_id in group_ids:
    number_of_relevant_embs = len(d[group_id])
    id_of_first_relevant_emb = d[group_id][0]
    id_of_last_relevant_emb = id_of_first_relevant_emb+number_of_relevant_embs
    weights[group_id][id_of_first_relevant_emb:id_of_last_relevant_emb] = 1/number_of_relevant_embs

Adjusted part of forward function (given batch of group_ids):

weight_batch = weights[batch]                                 # (batch_size, num_embeddings)
batch_of_averaged_embeddings = torch.matmul(weight_batch, P)  # (batch_size, 50)

Unfortunately, since weights_shape is around (30000, 40000), this approach already fails when creating the weights tensor with torch.zeros: I get a CUDA out of memory RuntimeError. With float64 that tensor needs 30000 · 40000 · 8 bytes ≈ 9.6 GB, which is almost the entire memory of my GPU.

Is what I did what you meant in the first place?

Yes. With that shape, maybe a sparse weights matrix will work. Another option is doing the matrix multiplication block by block, but that may not help if you’re recording gradients for P here. Your initial code can also be reformulated as a matmul done row by row (one dot product per group): if you use the JIT and remove the dict lookup and the list append (without gradients you can write the results into a preallocated tensor; with gradients torch.cat may be unavoidable), it should be reasonably fast…
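
A hedged sketch of the sparse option, again with the toy d; a COO tensor stores only the nonzero 1/emb_count entries, so even the (30000, 40000) case stays small:

import torch

P = torch.nn.Parameter(torch.randn(10, 50))
d = {0: [0], 1: [2, 3, 4], 2: [5, 6, 7, 8, 9]}

# (row, col, value) triplets: row = group id, col = embedding id
rows, cols, vals = [], [], []
for group_id, emb_idxs in d.items():
    for emb_idx in emb_idxs:
        rows.append(group_id)
        cols.append(emb_idx)
        vals.append(1.0 / len(emb_idxs))

sparse_weights = torch.sparse_coo_tensor(
    torch.tensor([rows, cols]), torch.tensor(vals),
    size=(len(d), P.shape[0]),
)

# all group averages in one sparse @ dense product, shape (num_groups, 50);
# gradients still flow into P through torch.sparse.mm
all_averages = torch.sparse.mm(sparse_weights, P)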

I tried the sparse matrix approach today, but I couldn’t get it to work because apparently there is no way to extract a batch from it like so:
weight_batch = sparse_weight_matrix[batch]
which I would need to get one set of weights per group_id in the batch.

I rather meant a weights matrix created per batch. But won’t sparse_weights.mm(P)[batch] work, too? It is kind of weird, as really only one mm() per epoch is needed, but training in batches complicates things…
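
Roughly, the per-batch variant could look like this (just a sketch with the toy P and d from earlier; the batch tensor and all names are illustrative):

import torch

P = torch.nn.Parameter(torch.randn(10, 50))
d = {0: [0], 1: [2, 3, 4], 2: [5, 6, 7, 8, 9]}
batch = torch.tensor([2, 0, 2, 1])  # batch of group ids

# one row per batch element: row i gets 1/emb_count in the columns of the
# embeddings that belong to group batch[i]
rows, cols, vals = [], [], []
for i, group_id in enumerate(batch.tolist()):
    emb_idxs = d[group_id]
    for emb_idx in emb_idxs:
        rows.append(i)
        cols.append(emb_idx)
        vals.append(1.0 / len(emb_idxs))

batch_weights = torch.sparse_coo_tensor(
    torch.tensor([rows, cols]), torch.tensor(vals),
    size=(len(batch), P.shape[0]), device=P.device,
)

# sparse (batch_size, num_embeddings) @ dense (num_embeddings, 50)
batch_of_averaged_embeddings = torch.sparse.mm(batch_weights, P)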

Also, you may try batch_map.smm(sparse_weights).mm(P), where batch_map is a sparse binary tensor of shape (batch_size, len(group_ids)). That’s just a clunky way to index sparse_weights; I’m not sure if there are better ways…

I tried three more things.

sparse_weights.mm(P)[batch]

as you suggested does indeed work. One epoch now takes ~3min20secs instead of the original 30 seconds, which is still not great, but better than the other options I tried. I have to determine whether this is really worth it.

The approach with

batch_map.smm(sparse_weights).mm(P)

did not work, because apparently .smm / .spmm are no longer available. There is an external library for sparse matrix multiplication, but apparently it has problems handling gradients properly.

One other thing I tried is the following: I realized that of the 30000 groups, only approximately 2000 are occurring in one batch of size ~35000.
So, given a batch, I extracted these unique group indices into a list, converted the list to a dense one-hot matrix, and multiplied it with the sparse group_weights matrix to get only the relevant weights. The resulting dense matrix, multiplied with P, gave me the ~2000 relevant averaged embeddings, from which I could then extract the batch of averaged embeddings corresponding to batch. Basically, I thought that by not calculating all of the averaged embeddings I could save time.
Unfortunately, contrary to what I had hoped, this approach took much longer again, at around 30 minutes per epoch.
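
Condensed, what I did looks roughly like this (a sketch with the toy sizes from the start of the thread; the names are just for illustration, and since torch.sparse.mm wants the sparse operand first, I write the one-hot product through transposes):

import torch

# toy setup as above: parameter, groups, sparse (num_groups, num_embeddings)
# weight matrix, and a batch of group ids
P = torch.nn.Parameter(torch.randn(10, 50))
d = {0: [0], 1: [2, 3, 4], 2: [5, 6, 7, 8, 9]}
idx = [[g, e] for g, embs in d.items() for e in embs]
val = [1.0 / len(d[g]) for g, _ in idx]
sparse_weights = torch.sparse_coo_tensor(
    torch.tensor(idx).t(), torch.tensor(val), size=(len(d), P.shape[0])
)
batch = torch.tensor([2, 2, 0, 2, 1, 1])

# groups that actually occur in the batch, and where each batch element
# sits inside that unique list
unique_groups, inverse = torch.unique(batch, return_inverse=True)

# dense one-hot selector: row i picks group unique_groups[i]
one_hot = torch.zeros(len(unique_groups), sparse_weights.shape[0])
one_hot[torch.arange(len(unique_groups)), unique_groups] = 1.0

# one_hot @ sparse_weights, done via the transposed sparse-dense product
relevant_weights = torch.sparse.mm(sparse_weights.t().coalesce(), one_hot.t()).t()

# averages only for the groups present in the batch, then mapped back
relevant_averages = relevant_weights @ P                   # (num_unique, 50)
batch_of_averaged_embeddings = relevant_averages[inverse]  # (batch_size, 50)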

Your problem is probably the elementwise assignments to a CUDA tensor, like you do with the dense weights above. You should do these on a CPU tensor; for the one-hot matrix you can even use the bool dtype, and then do .to(device="cuda", dtype=XXX). Actually, I don’t know whether such a sparse/dense matmul is faster on CUDA at all, especially with float64.

Sorry, where do I do these elementwise assignments you’re speaking of?

assignments to slices, to be exact; anything like:

`weights[group_id][id_of_first_relevant_emb:id_of_last_relevant_emb] = 1/number_of_relevant_embs`

in a loop. Each such statement on a CUDA tensor launches its own kernel, so a long Python loop of them produces a lot of CUDA commands.
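
For example, something like this builds the dense weights without any per-element CUDA work (just a sketch, assuming a CUDA device and the toy d from above):

import torch

d = {0: [0], 1: [2, 3, 4], 2: [5, 6, 7, 8, 9]}
num_embeddings = 10

# do the loop of assignments on a CPU tensor, where each assignment is cheap
# and launches no CUDA kernel ...
weights = torch.zeros(len(d), num_embeddings)
for group_id, emb_idxs in d.items():
    weights[group_id, emb_idxs] = 1.0 / len(emb_idxs)

# ... then move the finished matrix to the GPU in a single transfer
weights = weights.to(device="cuda")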