I want to roll each row of a tensor by its own shift amount, taken from a list of shifts. Here is a toy example. Can this be vectorized with `torch.vmap`?
import torch
def custom_roll(tensor, shifts):
    """Roll each row of ``tensor`` by its own shift, fully vectorized.

    Row ``i`` is rolled along dimension 1 by ``shifts[i]`` positions, matching
    ``torch.roll(tensor[i], shifts[i], dims=0)``. Positive shifts move elements
    toward higher indices and negative shifts toward lower ones. Shifts whose
    magnitude is at least the row length wrap around.

    The per-row slicing ``row[-shift:]`` is not a good fit for ``torch.vmap``,
    because slice bounds must be Python ints rather than batched tensors.
    Instead, the whole roll is expressed as a single ``gather`` with a
    precomputed index, which needs no Python-level loop at all.

    Args:
        tensor: Tensor with at least 2 dimensions. Dimension 0 indexes rows and
            dimension 1 is the axis that gets rolled. Any trailing dimensions
            are carried along unchanged.
        shifts: Sequence of ints or 1-D integer tensor with one shift per row.

    Returns:
        A new tensor with the same shape, dtype and device as ``tensor``.

    Raises:
        ValueError: If ``tensor`` has fewer than 2 dimensions, or if ``shifts``
            does not hold exactly one entry per row.
    """
    if tensor.dim() < 2:
        raise ValueError(f"tensor must have at least 2 dims, got {tensor.dim()}")
    shifts = torch.as_tensor(shifts, dtype=torch.long, device=tensor.device)
    if shifts.shape != (tensor.size(0),):
        raise ValueError(
            f"expected {tensor.size(0)} shifts (one per row), "
            f"got shape {tuple(shifts.shape)}"
        )
    length = tensor.size(1)
    if length == 0:
        # Nothing to roll; this also avoids a modulo by zero below.
        return tensor.clone()
    # out[i, j] = tensor[i, (j - shifts[i]) mod length]. torch.remainder uses
    # the sign of the divisor, so negative and oversized shifts wrap correctly.
    positions = torch.arange(length, device=tensor.device)
    index = torch.remainder(positions.unsqueeze(0) - shifts.unsqueeze(1), length)
    # gather requires an index of the same rank as the input, so broadcast the
    # (rows, length) index across any trailing dimensions.
    trailing = (1,) * (tensor.dim() - 2)
    index = index.reshape(*index.shape, *trailing).expand_as(tensor)
    return tensor.gather(1, index)
# Demo input: a 3x5 matrix holding 1..15 row by row; row i is rolled by i + 1.
tensor = torch.arange(1, 16).reshape(3, 5)
shifts = torch.arange(1, 4)
result = custom_roll(tensor, shifts)
print(result)