Given a multidimensional tensor, I want to find the position of the maximum along the last dimension for each row, and then extract the 50 elements before and after each maximum.

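import torch

x = torch.rand(2, 4, 256)
peak, i = torch.max(x[..., 50:-50], dim=-1)
i = i + 50  # max was taken over x[..., 50:-50], so shift indices back to positions in x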

But when I try to slice using the indices in i, I get an error:

x_subset = x[...,i-50 : i+50]
TypeError: only integer tensors of a single element can be converted to an index
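The error appears because the bounds of a Python slice must be scalars, while i holds one index per row (a 2 x 4 tensor here). A slow but simple sketch that sidesteps this is to loop over the rows and turn each index into a Python int with .item(). It assumes, as above, that the max is searched in x[..., 50:-50], so every 100-element window fits inside x.

```python
import torch

x = torch.rand(2, 4, 256)
peak, i = torch.max(x[..., 50:-50], dim=-1)
i = i + 50  # indices were relative to the 50:-50 slice; shift back to x

# Slice each row on its own; .item() converts a 0-d tensor to a Python int,
# which is what a slice bound needs.
windows = torch.stack([
    x[a, b, i[a, b].item() - 50 : i[a, b].item() + 50]
    for a in range(x.shape[0])
    for b in range(x.shape[1])
]).reshape(*x.shape[:-1], 100)  # back to shape (2, 4, 100)
```

Position 50 of each window is the peak itself; the Python-level loop gets slow for large leading dimensions.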

Is there a good way to apply a different slice to each row of a tensor? Thanks.
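One vectorized approach, sketched here under the same assumptions (window of 50 on each side, max searched in the interior so no window runs off the edge), is to build an index tensor of shape (..., 100) by broadcasting and pass it to torch.gather. The helper name windows_around_max is hypothetical, chosen for illustration.

```python
import torch

def windows_around_max(x: torch.Tensor, half: int = 50) -> torch.Tensor:
    """Hypothetical helper: for each row, return x[..., p - half : p + half],
    where p is the argmax of that row searched over x[..., half:-half]."""
    _, i = torch.max(x[..., half:-half], dim=-1)
    centers = i + half  # shift back from slice coordinates to x coordinates
    offsets = torch.arange(-half, half, device=x.device)  # shape (2*half,)
    idx = centers.unsqueeze(-1) + offsets  # broadcasts to (..., 2*half)
    # gather picks x[..., idx[..., k]] independently for every row
    return torch.gather(x, -1, idx)

x = torch.rand(2, 4, 256)
x_subset = windows_around_max(x)
print(x_subset.shape)  # torch.Size([2, 4, 100])
```

As with a slice, the window is half-open: 50 elements before the peak, the peak, and 49 after. Using torch.arange(-half, half + 1) instead would give 50 on both sides.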