Identifying the sub-window of a PyTorch tensor with the largest sum of per-feature means

Hello everyone,

I’m trying to identify, for each element along dimension 0 of `test`, the sub-window of length `window_len` (sliding along dimension 1, the original sequence length) that has the largest score, where the score is computed by averaging each feature over the sub-window and then summing those averages along dimension 2 (the features). The maximum is taken across all sub-window positions along dimension 1.
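Written out, with x standing for `test` of shape (B, L, F) and w for `window_len` (these symbols are introduced here only for illustration, with 0-based indices as in PyTorch), the score of the window starting at position i for batch element b, and the chosen start, are:

$$
s_{b,i} = \sum_{f=0}^{F-1} \frac{1}{w} \sum_{t=i}^{i+w-1} x_{b,t,f}, \qquad i = 0, \dots, L - w,
\qquad
i^{*}_{b} = \arg\max_{i} \, s_{b,i}
$$

With L = 2307 and w = 50 there are 2307 - 50 + 1 = 2258 candidate windows per batch element, and the desired output for element b is the slice x[b, i*_b : i*_b + w, :], giving shape (3, 50, 2).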

Here’s my code:

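```python
import torch

test = torch.randn(3, 2307, 2)  # (batch, seq_len, features)

window_len = 50

# Sliding windows of length window_len along dim 1 (seq_len) with stride 1:
# shape (3, 2258, 2, 50) = (batch, n_windows, features, window_len)
unfolded = test.unfold(len(test.shape) - 2, window_len, 1)

# Mean over each window, sum over features, then the best window per batch element
_, indices = torch.max(torch.sum(torch.mean(unfolded, dim=-1), dim=2), dim=1)

# Should match this result: torch.Size([3, 50, 2])
result = torch.stack([
    unfolded[0, :, :, :][indices[0], :, :].T,
    unfolded[1, :, :, :][indices[1], :, :].T,
    unfolded[2, :, :, :][indices[2], :, :].T
])
```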
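The per-element `torch.stack` above can likely be replaced by one advanced-indexing call: pairing `torch.arange(batch)` with `indices` selects one window per batch element, and a `transpose` restores the (window_len, features) order. A minimal sketch reusing the same shapes as the post (`scores`, `batch_idx` and `best` are names chosen here for illustration):

```python
import torch

torch.manual_seed(0)
test = torch.randn(3, 2307, 2)   # (batch, seq_len, features)
window_len = 50

# (batch, n_windows, features, window_len): a view with stride 1 along dim 1
unfolded = test.unfold(1, window_len, 1)

# Score each window: mean over the window, then sum over features -> (batch, n_windows)
scores = unfolded.mean(dim=-1).sum(dim=-1)
indices = scores.argmax(dim=1)   # (batch,)

# Select window indices[b] for every batch element b in a single indexing call,
# then swap the last two dims: (batch, features, window_len) -> (batch, window_len, features)
batch_idx = torch.arange(test.shape[0])
best = unfolded[batch_idx, indices].transpose(1, 2)
print(best.shape)  # torch.Size([3, 50, 2])
```

This generalizes to any batch size without writing one line per element.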

Is there a more efficient and straightforward way to get this solution using PyTorch’s functions?
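One possible direction, sketched under assumptions rather than offered as the definitive answer: because mean and sum are both linear, the score equals the window mean of the feature sums, so the features can be summed first and every window sum obtained from a prefix sum (`cumsum`) in time linear in the sequence length, instead of averaging all 50 elements of each of the 2258 windows. The winning window is then gathered directly from `test` with index arithmetic, which also avoids the transpose. The helper name `best_window` is hypothetical; the float64 prefix sum is an assumption meant to limit round-off, and near-ties could still resolve differently from the `unfold` version.

```python
import torch
import torch.nn.functional as F


def best_window(x: torch.Tensor, window_len: int):
    """Hypothetical helper. x: (batch, seq_len, features).

    Returns (windows, starts): windows has shape (batch, window_len, features),
    starts holds the chosen start position along dim 1 for each batch element.
    """
    # Sum over features first; this does not change which window wins.
    per_step = x.sum(dim=-1)                                    # (batch, seq_len)
    # Prefix sums in float64, with a zero column prepended so csum[:, k] = sum of the first k steps.
    csum = F.pad(per_step.double().cumsum(dim=1), (1, 0))      # (batch, seq_len + 1)
    # Sum of each window [i, i + window_len) as a difference of prefix sums.
    window_sums = csum[:, window_len:] - csum[:, :-window_len]  # (batch, seq_len - window_len + 1)
    # Dividing by window_len would not change the argmax, so it is skipped.
    starts = window_sums.argmax(dim=1)                          # (batch,)
    # Row indices starts[b], ..., starts[b] + window_len - 1 for each batch element b.
    rows = starts[:, None] + torch.arange(window_len, device=x.device)  # (batch, window_len)
    batch_idx = torch.arange(x.shape[0], device=x.device)[:, None]      # (batch, 1), broadcasts with rows
    return x[batch_idx, rows], starts


torch.manual_seed(0)
test = torch.randn(3, 2307, 2)
windows, starts = best_window(test, 50)
print(windows.shape)  # torch.Size([3, 50, 2])
```

Whether this is actually faster than the `unfold` version for sizes like these would need to be measured; the main gain is that the work no longer grows with `window_len`.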

Thanks,
Wilson