Vectorize for-loop - need to average slices of varying size

I'm not sure how to edit the original post, so I'm posting this as a comment instead.

I'll use a simpler example.

sequence_output is a tensor of shape B * 3 * 2, where 3 is the max sequence length and 2 is the number of features.

all_token_mapping is a tensor of shape B * 3 * 2; each entry is a [start, end] index pair, with [-1, -1] used as padding.

initial_reps is a tensor of shape num_nodes * 2, where num_nodes is the total number of words (not subwords) across all samples.

sequence_output = torch.arange(2*3*2).float().reshape(2, 3, 2)
tensor([[[ 0.,  1.],
         [ 2.,  3.],
         [ 4.,  5.]],

        [[ 6.,  7.],
         [ 8.,  9.],
         [10., 11.]]])
all_token_mapping = torch.tensor([[[0,0],[1,2],[-1,-1]], [[0,2],[-1,-1],[-1,-1]]])
tensor([[[ 0,  0],
         [ 1,  2],
         [-1, -1]],

        [[ 0,  2],
         [-1, -1],
         [-1, -1]]])
num_nodes = 0
for sample in all_token_mapping:
  for mapping in sample:
    if mapping[0] != -1:
      num_nodes += 1
num_nodes
3
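As an aside, the counting loop itself can be replaced by a boolean mask (a small sketch, using the -1 padding convention from above):

```python
import torch

all_token_mapping = torch.tensor([[[0, 0], [1, 2], [-1, -1]],
                                  [[0, 2], [-1, -1], [-1, -1]]])

# A mapping is real iff its start index is not the -1 padding value
num_nodes = (all_token_mapping[..., 0] != -1).sum().item()  # 3
```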
initial_reps = torch.empty((num_nodes, 2), dtype=torch.float32)
current_idx = 0
for i, feature_tokens_mapping in enumerate(all_token_mapping):
    for j, token_mapping in enumerate(feature_tokens_mapping):
        if token_mapping[0] == -1: # reached the end for this particular sequence
            break
        initial_reps[current_idx] = torch.mean(sequence_output[i][token_mapping[0]:token_mapping[-1] + 1], 0, keepdim=True)                                           
        current_idx += 1
initial_reps
tensor([[0., 1.],
        [3., 4.],
        [8., 9.]])

In the example above, initial_reps[0] will be the mean of sequence_output[0][0:1], initial_reps[1] will be the mean of sequence_output[0][1:3], and initial_reps[2] will be the mean of sequence_output[1][0:3].

My current code creates an empty tensor of length num_nodes, then a for loop fills in each row, using token_mapping[0] and token_mapping[1] to pick the slice of sequence_output to average.

Is there a way to vectorize this code?
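One possible vectorization (a sketch against the toy tensors above, not tested at scale): flatten the valid mappings, build a num_nodes x max_len weight matrix whose rows hold 1/span_length inside each [start, end] span and 0 elsewhere, then take a weighted sum over the token axis with einsum.

```python
import torch

sequence_output = torch.arange(2*3*2).float().reshape(2, 3, 2)
all_token_mapping = torch.tensor([[[0, 0], [1, 2], [-1, -1]],
                                  [[0, 2], [-1, -1], [-1, -1]]])

valid = all_token_mapping[..., 0] != -1                    # B x 3 mask of real mappings
starts = all_token_mapping[..., 0][valid]                  # (num_nodes,)
ends = all_token_mapping[..., 1][valid]                    # (num_nodes,)
batch_idx = (torch.arange(sequence_output.shape[0])
             .unsqueeze(1).expand_as(valid)[valid])        # which sample each node came from

# One weight row per node: 1/span_length inside [start, end], 0 outside
positions = torch.arange(sequence_output.shape[1])         # (max_len,)
in_span = (positions >= starts.unsqueeze(1)) & (positions <= ends.unsqueeze(1))
weights = in_span.float() / (ends - starts + 1).float().unsqueeze(1)

# Weighted sum over the token axis == mean over each node's span
initial_reps = torch.einsum('nt,ntf->nf', weights, sequence_output[batch_idx])
# matches the loop's result (up to float rounding)
```

This trades the Python loop for a num_nodes x max_len intermediate, which is fine when max_len is small but worth keeping in mind for long sequences.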

In addition, I have a list holding the number of words in each sample, i.e. the sum of its elements == num_nodes.
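If it helps, that list can also recover the per-sample grouping of initial_reps via torch.split (a sketch; words_per_sample is a hypothetical name for that list):

```python
import torch

initial_reps = torch.tensor([[0., 1.], [3., 4.], [8., 9.]])
words_per_sample = [2, 1]  # hypothetical word counts; sums to num_nodes

# torch.split cuts the node axis back into one chunk per sample
per_sample = torch.split(initial_reps, words_per_sample, dim=0)
```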