Hi hzzt!
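gather() will do this if you prepare an index tensor and then offset it with
i (the argmax() of the central slice x[..., 50:-50]). Because i is measured
from the start of that slice, the peak sits at position i + 50 in x itself,
so ind + i selects the 100 elements x[..., i : i + 100], with the peak at
position 50 of the result: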
```
>>> import torch
>>> print (torch.__version__)
2.1.2
>>>
>>> _ = torch.manual_seed (2024)
>>>
>>> x = torch.rand (2, 4, 256)
>>> peak, i = torch.max (x[..., 50:-50], dim = -1)
>>>
>>> ind = torch.arange (100).repeat (*x.size()[0:2], 1)
>>> x_subset = x.gather (2, ind + i.unsqueeze (-1))
>>>
>>> peak
tensor([[0.9945, 0.9997, 0.9907, 0.9982],
        [0.9968, 0.9943, 0.9872, 0.9989]])
>>> x_subset[:, :, 50]
tensor([[0.9945, 0.9997, 0.9907, 0.9982],
        [0.9968, 0.9943, 0.9872, 0.9989]])
```
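If you need this in more than one place, here is a minimal sketch of the same trick wrapped in a helper. The names `window_around_peak` and `window_around_peak_loop` and the `half` parameter (generalizing the hard-coded 50 / 100) are hypothetical, not from the session above. It assumes `half >= 1` and that the last dimension is longer than `2 * half`. It builds the index by broadcasting `i.unsqueeze(-1)` against `torch.arange(2 * half)` instead of calling `repeat()`; both give the same full-size index. The loop version is only a slow cross-check of the gather() result.

```python
import torch


def window_around_peak(x: torch.Tensor, half: int = 50) -> torch.Tensor:
    # Hypothetical helper: for every row along the last dim, find the maximum
    # of the central slice x[..., half:-half] and return the 2 * half elements
    # x[..., i : i + 2 * half], so the peak lands at position `half`.
    # Assumes half >= 1 and x.shape[-1] > 2 * half.
    _, i = torch.max(x[..., half:-half], dim=-1)  # i is relative to the slice
    offsets = torch.arange(2 * half, device=x.device)
    # Broadcasting (..., 1) + (2 * half,) gives the full (..., 2 * half) index,
    # which plays the same role as the repeat()-ed index tensor.
    index = i.unsqueeze(-1) + offsets
    return x.gather(-1, index)


def window_around_peak_loop(x: torch.Tensor, half: int = 50) -> torch.Tensor:
    # Slow reference version with explicit Python loops, used only as a check.
    n = x.shape[-1]
    rows = []
    for row in x.reshape(-1, n):
        p = int(torch.argmax(row[half:-half])) + half  # peak position in row
        rows.append(row[p - half : p + half])
    return torch.stack(rows).reshape(*x.shape[:-1], 2 * half)


torch.manual_seed(2024)
x = torch.rand(2, 4, 256)
w = window_around_peak(x)
print(w.shape)  # torch.Size([2, 4, 100])
print(torch.equal(w[..., 50], x[..., 50:-50].amax(dim=-1)))  # True
print(torch.equal(w, window_around_peak_loop(x)))  # True
```

If the last dimension can be as short as `2 * half`, the slice `x[..., half:-half]` is empty and the reduction in torch.max will fail, so you would need to guard against that case.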
Best.
K. Frank