Averaging tensor values based on indices

I have a tensor logits of shape (batch_size, seq_length, vocab_size). It is the output of a transformer model: for each element in the batch and each sequence position, it holds a score for every vocabulary token.

torch.max(logits, dim=2)

returns the maximum values along the vocabulary dimension and their corresponding indices (token ids). For example, with a batch size of 1:

torch.return_types.max(values=tensor([[ -7.8091,  -4.1528, -14.3334,  -3.2233,  -1.9405,  -3.9196,  -1.6090,
          -6.7052,  -7.1221,  -7.4800]], grad_fn=<MaxBackward0>), indices=tensor([[32099,    31,     7,    34,     5,     1,     1,     3,     3,     3]]))
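For reference, here is a minimal sketch of the shapes involved. It uses a random tensor with hypothetical sizes in place of real model output. Taking the maximum over dim=2 collapses the vocabulary dimension, which leaves one value and one token index for each (batch element, position) pair.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration
batch_size, seq_length, vocab_size = 2, 10, 100
logits = torch.randn(batch_size, seq_length, vocab_size)

# Maximum over the vocabulary dimension; returns a (values, indices) named tuple
values, indices = torch.max(logits, dim=2)
print(values.shape)   # torch.Size([2, 10])
print(indices.shape)  # torch.Size([2, 10])
```

Both outputs have shape (batch_size, seq_length), so any per-row averaging can be done along dim=1.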

I want to compute the average of the maximum values that come before the first occurrence of index 1 in indices. Here the first 1 is at position 5, so that means averaging positions 0 to 4:

average([ -7.8091,  -4.1528, -14.3334,  -3.2233,  -1.9405])

I need this for every element in the batch. What is the most efficient way to do this?
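One possible vectorized approach avoids looping over the batch. It assumes values and indices are the two outputs of torch.max(logits, dim=2). The idea is to build a mask that is True only before the first 1 in each row, using a cumulative sum over the comparison indices == 1, and then take a masked mean. The helper name mean_before_stop and its stop_index parameter are hypothetical and exist only for this sketch.

```python
import torch


def mean_before_stop(values: torch.Tensor, indices: torch.Tensor, stop_index: int = 1) -> torch.Tensor:
    # Hypothetical helper: averages values over the positions that come
    # before the first occurrence of stop_index in each row of indices.
    # values, indices: (batch_size, seq_length), e.g. from torch.max(logits, dim=2)
    # Returns a (batch_size,) tensor.

    # Running count of stop_index hits; it is 0 only before the first hit
    hits_so_far = (indices == stop_index).long().cumsum(dim=1)
    keep = hits_so_far == 0

    # Zero out positions at or after the first hit, then sum each row
    total = values.masked_fill(~keep, 0.0).sum(dim=1)

    # Number of kept positions per row; clamp avoids 0/0 for rows that start with stop_index
    count = keep.sum(dim=1).clamp(min=1)
    return total / count


values = torch.tensor([[-7.8091, -4.1528, -14.3334, -3.2233, -1.9405,
                        -3.9196, -1.6090, -6.7052, -7.1221, -7.4800]])
indices = torch.tensor([[32099, 31, 7, 34, 5, 1, 1, 3, 3, 3]])

# Mean of the first five values (positions 0-4), roughly -6.2918
print(mean_before_stop(values, indices))
```

With this approach, a row that contains no 1 is averaged over its full length. A row whose first index is 1 has nothing to average, and the clamp makes it return 0 instead of nan. Which of these behaviors you want depends on your use case. Everything stays as tensor operations, so the code works on any batch size, and gradients still flow through values.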