Different slicing of a tensor at different rows

Given a multidimensional tensor, I want to find the max position for each row in the last dimension, and then extract 50 elements before and after each max position.

import torch

x = torch.rand(2, 4, 256)
peak, i = torch.max(x[...,50:-50], dim=-1)

But when I try to slice using the indices i, I get an error:

x_subset = x[...,i-50 : i+50]

TypeError: only integer tensors of a single element can be converted to an index

Is there a good way to use different slices on different rows in a tensor? Thanks.
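
A plain Python loop makes the intended result concrete. Slice bounds must be Python integers (or single-element tensors, as the TypeError says), so a multi-element i cannot be used directly. Also, i is measured from the start of the trimmed view x[..., 50:-50], so the peak itself sits at x[..., i + 50], and a 100-element window starting 50 elements before it is x[..., i : i + 100], not x[..., i - 50 : i + 50]. The following is a slow reference sketch under those assumptions. The function name windows_by_loop and its half parameter are hypothetical and not part of PyTorch.

```python
import torch


def windows_by_loop(x: torch.Tensor, half: int = 50):
    """Slow reference (hypothetical helper): one 2*half-wide window per row.

    The peak is searched in the trimmed view x[..., half:-half], as in the
    question, so the index i it returns is offset by `half` relative to x.
    """
    peak, i = torch.max(x[..., half:-half], dim=-1)
    # Flatten all leading dimensions so each row can be visited in turn.
    rows = x.reshape(-1, x.shape[-1])
    # .tolist() turns the index tensor into plain Python ints, which slicing accepts.
    starts = i.reshape(-1).tolist()
    # The peak sits at row[s + half], so each window has `half` elements before it.
    windows = [row[s : s + 2 * half] for row, s in zip(rows, starts)]
    return torch.stack(windows).reshape(*x.shape[:-1], 2 * half), peak


x = torch.rand(2, 4, 256)
x_subset, peak = windows_by_loop(x)
print(x_subset.shape)  # torch.Size([2, 4, 100])
assert torch.equal(x_subset[..., 50], peak)
```

This visits every row in Python, so it is only practical as a correctness check for small tensors.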

Hi hzzt!

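gather() will do this if you prepare an index tensor and then offset it with
i. Note that i is the argmax() of the trimmed tensor x[..., 50:-50], so the
peak sits at x[..., i + 50] and the 100-element window around it starts at
x[..., i]: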

>>> import torch
>>> print (torch.__version__)
2.1.2
>>>
>>> _ = torch.manual_seed (2024)
>>>
>>> x = torch.rand (2, 4, 256)
>>> peak, i = torch.max (x[..., 50:-50], dim = -1)
>>>
>>> ind = torch.arange (100).repeat (*x.size()[0:2], 1)
>>> x_subset = x.gather (2, ind + i.unsqueeze (-1))
>>>
>>> peak
tensor([[0.9945, 0.9997, 0.9907, 0.9982],
        [0.9968, 0.9943, 0.9872, 0.9989]])
>>> x_subset[:, :, 50]
tensor([[0.9945, 0.9997, 0.9907, 0.9982],
        [0.9968, 0.9943, 0.9872, 0.9989]])
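
The repeat() in the session above isn't strictly needed. Adding i.unsqueeze(-1) to torch.arange(100) already broadcasts to the full index shape that gather() expects. The sketch below wraps that in a hypothetical helper, windows_by_gather, and cross-checks it against a swapped-in alternative, Tensor.unfold(), which builds a view of every 100-element window and then picks window i in each row. Both helper names and the half parameter are illustrative assumptions, not anything from this thread or the PyTorch API.

```python
import torch


def windows_by_gather(x: torch.Tensor, half: int = 50):
    """Vectorized (hypothetical helper): gather a 2*half-wide window per row."""
    # i indexes the trimmed view x[..., half:-half], so x[..., i + half] is the
    # peak and the window we want is x[..., i : i + 2*half].
    peak, i = torch.max(x[..., half:-half], dim=-1)
    # Broadcasting (..., 1) + (2*half,) yields the full (..., 2*half) index
    # tensor, so no explicit repeat() is required.
    ind = i.unsqueeze(-1) + torch.arange(2 * half, device=x.device)
    return x.gather(-1, ind), peak


def windows_by_unfold(x: torch.Tensor, half: int = 50):
    """Alternative (hypothetical helper): build all windows, pick one per row."""
    _, i = torch.max(x[..., half:-half], dim=-1)
    # View of all length-2*half windows along the last dim (step 1, no copy):
    # shape (..., x.shape[-1] - 2*half + 1, 2*half); window s starts at x[..., s].
    windows = x.unfold(x.dim() - 1, 2 * half, 1)
    # Select window number i in each row by gathering along the window dimension.
    index = i[..., None, None].expand(*i.shape, 1, 2 * half)
    return windows.gather(-2, index).squeeze(-2)


torch.manual_seed(2024)
x = torch.rand(2, 4, 256)
a, peak = windows_by_gather(x)
b = windows_by_unfold(x)
assert torch.equal(a, b)
assert torch.equal(a[..., 50], peak)
```

The unfold() view is cheap to create, but the extra window dimension makes the indexing harder to follow, so the broadcast gather() form is probably the simpler one to keep.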

Best.

K. Frank
