I am trying to process a tensor segment by segment, where the segments are given as a list and each segment undergoes a different operation. I tried to use torch.vmap for this, but it fails with the error “RuntimeError: vmap: It looks like you’re calling .item() on a Tensor.” Here is a simplified version of my code:
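import torch

segment_size = torch.randint(1, 6, size=(6,))
idx_end = torch.cumsum(segment_size, dim=0)
idx_start = torch.roll(idx_end, 1)
idx_start[0] = 0  # the first segment starts at index 0
# Row 0 holds the start indices, row 1 the end indices
idx_start_end = torch.cat((idx_start.unsqueeze(0), idx_end.unsqueeze(0)), dim=0)
x = torch.randn(idx_end[-1])

def calculate(start_end, x):
    start = start_end[0]
    end = start_end[1]
    selected_x = x[start:end]  # slice bounds are tensors here
    selected_x = 0  # placeholder for the per-segment operation
    return selected_x

# Map over the columns of idx_start_end, one (start, end) pair per segment
torch.vmap(calculate, in_dims=(1, None))(idx_start_end, x)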
I would like to know how to resolve this issue or if there is a more efficient way to implement it.
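For reference, here is a minimal sketch of one direction that may avoid the error. It assumes the per-segment operation is a plain sum (my placeholder above does not say what it is), and it replaces the slice x[start:end] with a boolean mask so that every mapped call sees tensors of the same shape. The names positions and segment_sum are hypothetical, introduced only for this illustration. Note that vmap applies the same function to every segment, so truly different operations per segment would probably still need a Python loop, for example over torch.split.

```python
import torch

torch.manual_seed(0)

segment_size = torch.randint(1, 6, size=(6,))
idx_end = torch.cumsum(segment_size, dim=0)
idx_start = idx_end - segment_size  # start index of each segment
x = torch.randn(int(idx_end[-1]))

# Position of every element in x; shared by all segments
positions = torch.arange(x.numel())

def segment_sum(start, end, x):
    # Assumption: the per-segment operation is a sum.
    # The mask has the same shape as x for every segment, so no
    # data-dependent slicing (and no .item() call) is needed.
    mask = (positions >= start) & (positions < end)
    return torch.where(mask, x, torch.zeros_like(x)).sum()

# One (start, end) pair per mapped call; x is shared across calls
sums_vmap = torch.vmap(segment_sum, in_dims=(0, 0, None))(idx_start, idx_end, x)

# Loop-based reference using torch.split with Python ints
chunks = torch.split(x, segment_size.tolist())
sums_loop = torch.stack([chunk.sum() for chunk in chunks])

print(torch.allclose(sums_vmap, sums_loop, atol=1e-5))
```

The mask version touches every element of x once per segment, so it trades extra work for a fixed shape. Whether that is faster than the loop over torch.split likely depends on the number and size of the segments.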