I have a tensor, logits, with shape (batch_size, seq_length, vocab_size).
It is the output of a transformer model: for every element in the batch and at each sequence position, it holds a score (logit) for each token in the vocabulary.
torch.max(logits, dim=2)
returns the maximum values over the vocabulary dimension and their corresponding indices (token IDs).
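A minimal sketch of what that call returns, assuming hypothetical sizes and random logits purely for illustration. Reducing over dim=2 drops the vocabulary dimension, so both values and indices have shape (batch_size, seq_length):

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration
batch_size, seq_length, vocab_size = 2, 10, 32100
logits = torch.randn(batch_size, seq_length, vocab_size)

# torch.max along the vocabulary dimension returns a named tuple (values, indices)
values, indices = torch.max(logits, dim=2)

print(values.shape)   # torch.Size([2, 10])
print(indices.shape)  # torch.Size([2, 10])
```

Here indices holds the argmax token ID at each position, which is what the "index 1" below refers to.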
For example, with a batch size of 1:
torch.return_types.max(values=tensor([[ -7.8091, -4.1528, -14.3334, -3.2233, -1.9405, -3.9196, -1.6090,
-6.7052, -7.1221, -7.4800]], grad_fn=<MaxBackward0>), indices=tensor([[32099, 31, 7, 34, 5, 1, 1, 3, 3, 3]]))
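I want to compute the average of the maximum values that come before the first occurrence of token ID 1 in indices.
In the example above, the first 1 is at position 5 (counting from 0), so the result should be
average([ -7.8091, -4.1528, -14.3334, -3.2233, -1.9405]) ≈ -6.2918
and I need this for every element in the batch.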
What is the most efficient way to do this?
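One possible vectorized approach, sketched under stated assumptions: build a boolean mask that is True only before the first stop token, using a cumulative sum over (indices == 1), then take a masked mean per row. The function name mean_before_stop and its stop_id parameter are hypothetical. It assumes that a row containing no 1 should average all of its positions, and it returns NaN (0/0) for a row whose very first token is 1.

```python
import torch

def mean_before_stop(values: torch.Tensor, indices: torch.Tensor, stop_id: int = 1) -> torch.Tensor:
    """Hypothetical helper: per-row mean of `values` at positions before the
    first occurrence of `stop_id` in `indices`.

    values, indices: shape (batch_size, seq_length), e.g. from torch.max(logits, dim=2).
    Assumptions: a row with no `stop_id` averages every position; a row whose
    first token is `stop_id` has nothing to average and yields NaN.
    """
    is_stop = indices == stop_id
    # The running count of stop tokens is 0 until the first one, then >= 1
    keep = is_stop.long().cumsum(dim=1) == 0
    keep_f = keep.to(values.dtype)
    total = (values * keep_f).sum(dim=1)  # sum of kept values per row
    count = keep_f.sum(dim=1)             # number of kept positions per row
    return total / count

# Usage with the example from the question (batch size 1)
values = torch.tensor([[-7.8091, -4.1528, -14.3334, -3.2233, -1.9405,
                        -3.9196, -1.6090, -6.7052, -7.1221, -7.4800]])
indices = torch.tensor([[32099, 31, 7, 34, 5, 1, 1, 3, 3, 3]])
print(mean_before_stop(values, indices))  # approximately -6.2918
```

Because the mask enters only through multiplication and summation, gradients still flow back to values (and so to logits) if this is used inside a training loop.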