I am trying to process a tensor segment by segment, where each segment undergoes a different operation. I tried using torch.vmap for this, but it fails with the error “RuntimeError: vmap: It looks like you’re calling .item() on a Tensor.” Here is a simplified version of my code:
import torch

# Random segment lengths and the matching [start, end) index pairs
segment_size = torch.randint(1, 6, size=(6,))
idx_end = torch.cumsum(segment_size, dim=0)
idx_start = torch.roll(idx_end, 1)
idx_start[0] = 0
idx_start_end = torch.cat((idx_start.unsqueeze(0), idx_end.unsqueeze(0)), dim=0)  # shape (2, 6)
x = torch.randn(idx_end[-1])

def calculate(start_end, x):
    start = start_end[0]
    end = start_end[1]
    # Slicing with tensor-valued bounds is what triggers the .item() error under vmap
    selected_x = x[start:end]
    selected_x = 0
    return selected_x

torch.vmap(calculate, in_dims=(1, None))(idx_start_end, x)
I would like to know how to resolve this error, or whether there is a more efficient way to implement this.
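For context on the error: slicing with `x[start:end]` needs Python integers, so the batched index tensors have to be converted with `.item()`, which vmap does not allow; the slices would also have different lengths per segment. A minimal sketch of one possible workaround follows, assuming the per-segment operation can be written over the full-length `x` with a boolean mask instead of a slice. The function name `masked_sum` and the choice of a sum as the operation are hypothetical stand-ins, not the actual operations from my code.

```python
import torch

torch.manual_seed(0)

# Same setup as above: random segment lengths and [start, end) pairs
segment_size = torch.randint(1, 6, size=(6,))
idx_end = torch.cumsum(segment_size, dim=0)
idx_start = idx_end - segment_size
idx_start_end = torch.stack((idx_start, idx_end), dim=0)  # shape (2, 6)
x = torch.randn(int(idx_end[-1]))

def masked_sum(start_end, x):
    # Hypothetical stand-in operation: sum of the elements in one segment.
    # A mask keeps every intermediate the same shape as x, so no
    # data-dependent slicing (and no .item() call) is needed.
    positions = torch.arange(x.shape[0])
    mask = (positions >= start_end[0]) & (positions < start_end[1])
    return torch.where(mask, x, torch.zeros_like(x)).sum()

# One result per segment, shape (6,)
segment_sums = torch.vmap(masked_sum, in_dims=(1, None))(idx_start_end, x)
print(segment_sums)
```

This does work proportional to the number of segments times the length of `x`, so for many short segments a plain Python loop over slices may still be the simpler choice. It also applies the same function to every segment; truly different operations per segment would need another mechanism.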