I have a tensor `logits` with shape `(batch_size, seq_length, vocab_size)`.

It is the output of a transformer model: for every element in the batch, it gives a score for each vocabulary token at each sequence position.

```
torch.max(logits, dim=2)
```

returns the maximum value at each sequence position and the corresponding vocabulary index.

For example, with a batch size of 1:

```
torch.return_types.max(values=tensor([[ -7.8091, -4.1528, -14.3334, -3.2233, -1.9405, -3.9196, -1.6090,
-6.7052, -7.1221, -7.4800]], grad_fn=<MaxBackward0>), indices=tensor([[32099, 31, 7, 34, 5, 1, 1, 3, 3, 3]]))
```

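I want to compute the average of the maximum values that come before the first occurrence of token index 1 in `indices`,

i.e. the values at positions 0 to 4, since the first 1 appears at position 5 in the example above: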

```
average([ -7.8091, -4.1528, -14.3334, -3.2233, -1.9405])
```

for all elements in the batch.

What is the most efficient way to achieve this?
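
One possible vectorized approach is sketched below; it is not a definitive answer. The idea is to mark every position before the first occurrence of token 1 with a cumulative sum over a boolean mask, then take a masked mean per row. The function name `mean_max_before_token` and the parameter `stop_id` are hypothetical names introduced for illustration.

```python
import torch


def mean_max_before_token(logits: torch.Tensor, stop_id: int = 1) -> torch.Tensor:
    """Per-row mean of the max values before the first `stop_id` (hypothetical helper).

    logits: tensor of shape (batch_size, seq_length, vocab_size).
    Returns a tensor of shape (batch_size,).
    """
    # Both results have shape (batch_size, seq_length).
    values, indices = torch.max(logits, dim=2)

    # True wherever the arg-max token equals the stop token.
    is_stop = indices == stop_id

    # The running count of stop tokens is 0 strictly before the first one
    # and at least 1 from the first one onward.
    keep = is_stop.cumsum(dim=1) == 0

    # Zero out the positions that should not contribute, then average.
    total = values.masked_fill(~keep, 0.0).sum(dim=1)
    counts = keep.sum(dim=1)

    # If the stop token sits at position 0, counts is 0 and the result is nan.
    return total / counts
```

Under this sketch, a row that never produces token 1 is averaged over the whole sequence, and a row whose first position is already token 1 yields `nan`; both cases may need different handling depending on what the average is used for.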