Code:
def foo(a, ind, len):
    ind2 = ind + len
    return a[ind:ind2]
torch.vmap(foo, in_dims=(0, 0, None))(a, c, 2)
RuntimeError: vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. If error is occurring somewhere inside PyTorch internals, please file a bug report.
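
The error most likely comes from the slice a[ind:ind2]: slice bounds have to be Python integers, so PyTorch converts the tensor ind with .item(), which vmap cannot do for a batched value. One possible workaround, sketched below, is to build the window indices as a tensor and use tensor indexing instead of slicing. The name take_window and the sample values of a and c are made up for illustration, and the sketch assumes the window length is a plain Python int and that ind + length never runs past the end of a row.

```python
import torch

def take_window(a, ind, length):
    # length is a Python int, so arange has a fixed shape for every batch element
    offsets = torch.arange(length, device=a.device)
    # index with a tensor of positions instead of slicing with tensor bounds
    return a[ind + offsets]

# illustrative inputs: 3 rows of 4 values, one start index per row
a = torch.arange(12.0).reshape(3, 4)
c = torch.tensor([0, 1, 2])

out = torch.vmap(take_window, in_dims=(0, 0, None))(a, c, 2)
print(out)
```

Unlike slicing, tensor indexing does not silently truncate at the end of the row, so an out-of-range start index raises an error rather than returning a shorter result.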