Vectorize for-loop - need to average slices of varying size

I am trying to average subword embeddings to form a word-level representation. Each word has a corresponding start and end index, indicating which subwords make up that word.

sequence_output is a tensor of shape B * 384 * 768, where 384 is the max sequence length and 768 is the number of features.

all_token_mapping is a tensor of shape B * 384 * 2, which holds a start and end index for each word. It is padded with [-1, -1].

initial_reps is a tensor of shape num_nodes * 768, where num_nodes is the total number of words (not subwords) across all samples.

initial_reps = torch.empty((num_nodes, 768), dtype=torch.float32)
current_idx = 0
for i, feature_tokens_mapping in enumerate(all_token_mapping):
    for j, token_mapping in enumerate(feature_tokens_mapping):
        if token_mapping[0] == -1: # reached the end for this particular sequence
            break
        initial_reps[current_idx] = torch.mean(sequence_output[i][token_mapping[0]:token_mapping[1] + 1], 0, keepdim=True)
        current_idx += 1

My current code creates an empty tensor of length num_nodes, and a for loop fills in each row by checking token_mapping[0] and token_mapping[1] for the correct slice of sequence_output to average.

Is there a way to vectorize this code?

In addition, I have a list that holds the number of words in each sample, i.e. the sum of all the elements in the list equals num_nodes.

Thank you.

Not sure how to edit the original post so posting as comment instead.

I'll use a simpler example.

sequence_output is a tensor of shape B * 3 * 2, where 3 is the max sequence length and 2 is the number of features.

all_token_mapping is a tensor of B * 3 * 2, which contains a start and end index.

initial_reps is a tensor of shape num_nodes * 2, where num_nodes is the total number of words (not subwords) across all samples.

sequence_output = torch.arange(2*3*2).float().reshape(2, 3, 2)
tensor([[[ 0.,  1.],
         [ 2.,  3.],
         [ 4.,  5.]],

        [[ 6.,  7.],
         [ 8.,  9.],
         [10., 11.]]])
all_token_mapping = torch.tensor([[[0,0],[1,2],[-1,-1]], [[0,2],[-1,-1],[-1,-1]]])
tensor([[[ 0,  0],
         [ 1,  2],
         [-1, -1]],

        [[ 0,  2],
         [-1, -1],
         [-1, -1]]])
num_nodes = 0
for sample in all_token_mapping:
  for mapping in sample:
    if mapping[0] != -1:
      num_nodes += 1
num_nodes
3
initial_reps = torch.empty((num_nodes, 2), dtype=torch.float32)
current_idx = 0
for i, feature_tokens_mapping in enumerate(all_token_mapping):
    for j, token_mapping in enumerate(feature_tokens_mapping):
        if token_mapping[0] == -1: # reached the end for this particular sequence
            break
        initial_reps[current_idx] = torch.mean(sequence_output[i][token_mapping[0]:token_mapping[1] + 1], 0, keepdim=True)
        current_idx += 1
initial_reps
tensor([[0., 1.],
        [3., 4.],
        [8., 9.]])

In the example above, initial_reps[0] will be the mean of sequence_output[0][0:1], initial_reps[1] will be the mean of sequence_output[0][1:3], and initial_reps[2] will be the mean of sequence_output[1][0:3].

My current code creates an empty tensor of length num_nodes, and a for loop fills in each row by checking token_mapping[0] and token_mapping[1] for the correct slice of sequence_output to average.

Is there a way to vectorize this code?

In addition, I have a list that holds the number of words in each sample, i.e. the sum of all the elements in the list equals num_nodes.
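As an aside, since those per-sample word counts sum to num_nodes, the list can also be used to split the flat initial_reps back into one tensor per sample. A small sketch (the name `words_per_sample` is assumed), using the values from the simpler example:

```python
import torch

# Sketch: split the flat word-level tensor back into per-sample tensors.
# `words_per_sample` is an assumed name; its values match the simpler
# example (2 words in the first sample, 1 word in the second).
initial_reps = torch.tensor([[0., 1.], [3., 4.], [8., 9.]])
words_per_sample = [2, 1]                     # sums to num_nodes == 3
per_sample = torch.split(initial_reps, words_per_sample)
# per_sample[0] has shape (2, 2); per_sample[1] has shape (1, 2)
```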

I think the encoding with “break” makes it a bit hard.
If you encode sequence ids instead (can be done with comparison and cumsum if you don’t have it already), you can use index_add or some scatter_add implementation.
Last I looked the public implementations for PyTorch were not terribly optimized but what can you do.
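A sketch of that idea on the simpler example: build a word id for every subword from the [start, end] pairs, then sum with index_add_ and divide by the counts. All variable names here (valid, starts, word_ids, ...) are my own, and I assume the [-1, -1] padding convention from the question:

```python
import torch

# Simpler example from the thread
sequence_output = torch.arange(2 * 3 * 2).float().reshape(2, 3, 2)
all_token_mapping = torch.tensor([[[0, 0], [1, 2], [-1, -1]],
                                  [[0, 2], [-1, -1], [-1, -1]]])

valid = all_token_mapping[..., 0] != -1      # which (sample, word) slots are real
starts = all_token_mapping[..., 0][valid]    # flat start index per word
ends = all_token_mapping[..., 1][valid]      # flat end index per word
lengths = ends - starts + 1                  # subwords per word
num_nodes = int(valid.sum())

# word id for every subword: repeat each word's id by its subword count
word_ids = torch.repeat_interleave(torch.arange(num_nodes), lengths)

# batch index for every subword: words per sample, then subwords per word
batch_ids = torch.repeat_interleave(torch.arange(sequence_output.size(0)),
                                    valid.sum(dim=1))
batch_ids = torch.repeat_interleave(batch_ids, lengths)

# position of every subword inside its sequence: offset within the word
# plus the word's start index
offsets = torch.cumsum(lengths, dim=0) - lengths
pos = (torch.arange(int(lengths.sum()))
       - torch.repeat_interleave(offsets, lengths)
       + torch.repeat_interleave(starts, lengths))

# gather all subword vectors, sum them per word, divide by the counts
subword_vecs = sequence_output[batch_ids, pos]
sums = torch.zeros(num_nodes, sequence_output.size(-1))
sums.index_add_(0, word_ids, subword_vecs)
initial_reps = sums / lengths.unsqueeze(1).float()
# initial_reps matches the loop version: [[0., 1.], [3., 4.], [8., 9.]]
```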


Thanks to tom, I found out that scatter_add exists, and from there I found torch_scatter’s segment_coo.
Here’s my solution now:

    initial_reps_list = []
    for i, sample_output in enumerate(sequence_output):
        token_mapping = all_token_mapping[i]
        token_mapping = token_mapping[token_mapping != -1]
        non_padded_outputs = sample_output[:num_bert_tokens[i]]
        initial_reps_list.append(torch_scatter.segment_coo(non_padded_outputs, token_mapping, reduce="mean"))
    initial_reps = torch.cat(initial_reps_list)

token_mapping is a list of indices in ascending order up to the max sequence length, padded with -1. I loop through the batch, for each sample, I get the token mapping, and only keep the non-negative indices.
num_bert_tokens is a list that holds, for each sample, the number of tokens (no padding). I get the non-padded outputs, use segment_coo to reduce them according to the token_mapping, and append them all to a list.
After the loop, I concatenate all the tensors in the list together.
The method segment_coo reduces all values from the src tensor into out at the indices specified in the index tensor along the last dimension of index. More details can be found at: Segment COO — pytorch_scatter 2.0.6 documentation
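If torch_scatter isn’t available, the same mean reduction can be sketched in plain PyTorch with index_add_ and bincount. `segment_mean` is a made-up helper name, and it assumes a 1-D, non-negative, sorted index like the filtered token_mapping above:

```python
import torch

# Stand-in sketch for torch_scatter.segment_coo(src, index, reduce="mean")
# along dim 0, assuming a 1-D index of non-negative segment ids.
def segment_mean(src, index):
    num_segments = int(index.max()) + 1
    sums = torch.zeros(num_segments, src.size(-1), dtype=src.dtype)
    sums.index_add_(0, index, src)                       # per-segment sums
    counts = torch.bincount(index, minlength=num_segments)
    return sums / counts.unsqueeze(1).to(src.dtype)

# second sample of the simpler example: three subwords forming one word
sample_output = torch.tensor([[6., 7.], [8., 9.], [10., 11.]])
token_mapping = torch.tensor([0, 0, 0])
word_reps = segment_mean(sample_output, token_mapping)   # [[8., 9.]]
```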

Code runs much faster now :)