How to batch torch.roll with torch.vmap?

I want to roll each row of a tensor by its own shift, taken from a list of shifts. Here is a toy example. Can this be vectorized with torch.vmap?

import torch

def custom_roll(tensor, shifts):
    """Roll each row of ``tensor`` by its own shift, like a batched ``torch.roll``.

    Row ``i`` is rolled along its first axis by ``shifts[i]`` positions, with
    the same semantics as ``torch.roll(tensor[i], shifts[i], dims=0)``:
    positive shifts move elements toward higher indices, negative shifts
    toward lower indices, and shifts whose magnitude exceeds the row length
    wrap around (modulo the row length).

    ``torch.roll`` takes its shifts as Python ints, so a per-row shift tensor
    cannot simply be batched through it. Instead, this builds the gather index
    ``(j - shift) % n`` for every row at once and performs a single
    ``torch.gather``. No Python-level loop over rows is needed.

    Args:
        tensor: Tensor of shape ``(rows, n, ...)``. Each row is rolled along
            dim 1 of ``tensor``; any trailing dims move together.
        shifts: Integer tensor or sequence of ints with exactly one entry
            per row.

    Returns:
        A new tensor with the same shape, dtype and device as ``tensor``.

    Raises:
        ValueError: If ``tensor`` has fewer than 2 dimensions, or if
            ``shifts`` does not have exactly one entry per row.
    """
    if tensor.dim() < 2:
        raise ValueError(
            f"custom_roll expects a tensor with at least 2 dims, got {tensor.dim()}"
        )

    shifts = torch.as_tensor(shifts, dtype=torch.long, device=tensor.device).reshape(-1)
    rows, n = tensor.shape[0], tensor.shape[1]
    if shifts.numel() != rows:
        # zip() used to truncate silently here; a mismatch is almost surely a caller bug.
        raise ValueError(
            f"expected one shift per row ({rows}), got {shifts.numel()} shifts"
        )
    if n == 0:
        # Nothing to roll; also avoids a modulo by zero below.
        return tensor.clone()

    positions = torch.arange(n, device=tensor.device)
    # out[i, j] = tensor[i, (j - shifts[i]) % n]. Tensor `%` is torch.remainder,
    # which is always non-negative for a positive divisor, so negative and
    # oversized shifts wrap exactly like torch.roll.
    index = (positions.unsqueeze(0) - shifts.unsqueeze(1)) % n
    # Broadcast the (rows, n) index over any trailing dims so whole slices move together.
    index = index.view(rows, n, *([1] * (tensor.dim() - 2))).expand_as(tensor)
    return torch.gather(tensor, 1, index)

# Toy input: three rows of five consecutive integers, 1..15 (int64).
tensor = torch.arange(1, 16).reshape(3, 5)

# One shift per row: row 0 by 1, row 1 by 2, row 2 by 3.
shifts = torch.arange(1, 4)

result = custom_roll(tensor, shifts)
# Expected output:
# tensor([[ 5,  1,  2,  3,  4],
#         [ 9, 10,  6,  7,  8],
#         [13, 14, 15, 11, 12]])
print(result)

1 Like