Local maxima of sub-tensors defined by an index tensor

Hello everyone!

I have a tensor x of shape (1, n) and an index tensor d of shape (1, k) whose entries are increasing segment end positions. I'm trying to find the maximum of each of the k sub-tensors x[0:d[0]], x[d[0]:d[1]], x[d[1]:d[2]], ..., x[d[-2]:d[-1]] (slicing along the last dimension). The output is therefore a tensor of shape (1, k) holding the k local maxima. I can implement this with a for loop, but that's too slow. Is there a way to compute all the segment maxima in parallel in PyTorch?
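To pin down the semantics, here is a minimal sketch (not from this thread) of the slow per-segment loop next to one loop-free alternative in plain PyTorch. It assumes x is floating point, d holds strictly increasing end positions, and PyTorch 1.12 or newer for Tensor.scatter_reduce with reduce="amax". The function names are hypothetical.

```python
import torch


def segment_max_loop(x: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
    """Reference version: one Python-level max per segment (slow for large k)."""
    starts = [0] + d[0, :-1].tolist()
    ends = d[0].tolist()
    # .max() on an empty slice raises, so empty segments are not supported here.
    maxima = [x[0, s:e].max() for s, e in zip(starts, ends)]
    return torch.stack(maxima).unsqueeze(0)  # shape (1, k)


def segment_max_vectorized(x: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
    """Loop-free version: label each position with its segment id, then scatter-reduce."""
    k = d.size(1)
    m = int(d[0, -1])  # elements at or after d[-1] belong to no segment
    positions = torch.arange(m, device=x.device)
    # With right=True, a position p with d[j-1] <= p < d[j] gets segment id j.
    seg_id = torch.bucketize(positions, d[0], right=True)
    # Start from -inf (requires a floating-point x); include_self=False means the
    # initial value is ignored wherever at least one element is scattered.
    out = torch.full((k,), float("-inf"), dtype=x.dtype, device=x.device)
    out = out.scatter_reduce(0, seg_id, x[0, :m], reduce="amax", include_self=False)
    return out.unsqueeze(0)  # shape (1, k); an empty segment would stay -inf


x = torch.tensor([[3., 1., 4., 1., 5., 9., 2., 6.]])  # shape (1, n) with n = 8
d = torch.tensor([[2, 5, 8]])                          # shape (1, k) with k = 3
print(segment_max_vectorized(x, d))  # tensor([[3., 5., 9.]])
```

The vectorized version runs a fixed number of tensor ops regardless of k, which is where it gains over the Python loop.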


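I found the answer: the segment_csr function in torch_scatter (with reduce="max") does the job. Note that its indptr argument expects CSR-style pointers with a leading 0, i.e. [0, d[0], ..., d[-1]].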
