RuntimeError: vmap: It looks like you're calling .item() on a Tensor

Code:

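import torch

a = torch.arange(12).reshape(3, 4)  # example input: 3 rows of length 4
c = torch.tensor([0, 1, 2])         # example input: one start index per row

def foo(a, ind, len):
    ind2 = ind + len
    return a[ind:ind2]

torch.vmap(foo, in_dims=(0, 0, None))(a, c, 2)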

RuntimeError: vmap: It looks like you’re calling .item() on a Tensor. We don’t support vmap over calling .item() on a Tensor, please try to rewrite what you’re doing with other operations. If error is occurring somewhere inside PyTorch internals, please file a bug report.
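A likely cause, with a workaround sketch: inside `torch.vmap`, `ind` is a per-sample tensor, so the slice `a[ind:ind2]` has to convert the tensor bounds to Python ints, which goes through `.item()`. vmap can't do that because each batch element may hold a different value. Because the window length here is fixed, one option is to build the window indices as a tensor and use advanced (tensor) indexing, which vmap does support. The example inputs `a` and `c`, and the renamed parameter `length`, are assumptions for illustration; the sketch also assumes every window fits inside its row.

```python
import torch

# Example inputs (assumed for illustration): 3 rows of length 4,
# one start index per row, and a fixed window length of 2.
a = torch.arange(12).reshape(3, 4)
c = torch.tensor([0, 1, 2])

def foo(a, ind, length):
    # Build the window indices as a tensor instead of slicing with
    # ind:ind + length, which would need Python ints (i.e. .item()).
    idx = ind + torch.arange(length, device=ind.device)
    # Indexing with a tensor index works under vmap.
    # Unlike slicing, out-of-range indices raise an error instead of clipping.
    return a[idx]

out = torch.vmap(foo, in_dims=(0, 0, None))(a, c, 2)
print(out.tolist())  # [[0, 1], [5, 6], [10, 11]]

# Cross-check without vmap, using batched advanced indexing.
rows = torch.arange(a.size(0)).unsqueeze(1)  # shape (3, 1)
cols = c.unsqueeze(1) + torch.arange(2)      # shape (3, 2)
assert torch.equal(out, a[rows, cols])
```

If some windows can run past the end of a row, you would need to pad `a` or clamp `idx` first, since tensor indexing does not truncate the way a slice does.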