Vmap runtime error

I tried to speed up a derivative computation in PyTorch using vmap. Here is my original code:

import torch

# First and second time derivatives of every coordinate, computed one scalar output at a time.
dx_over_dt = torch.zeros_like(new_confs)        # (batch_size, num_atoms, 3)
d2x_over_dt2 = torch.zeros_like(new_confs)      # (batch_size, num_atoms, 3)
grad_mask = torch.ones(ts.shape[0]).to(device)  # grad_outputs for a (batch_size,) output
for i in range(num_atoms):
    for j in range(3):
        dx_over_dt[:, i, j] = torch.autograd.grad(
            new_confs[:, i, j], ts, create_graph=True,
            retain_graph=True, grad_outputs=grad_mask
        )[0]
        d2x_over_dt2[:, i, j] = torch.autograd.grad(
            dx_over_dt[:, i, j], ts, create_graph=True,
            grad_outputs=grad_mask
        )[0]

new_confs is the output of my model, and ts is one of the inputs to the model, with ts.requires_grad = True. For context, the shapes involved look roughly like this (the model call below is only a placeholder for my actual network):
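
# Illustrative shapes only; "model" stands in for my real network, which takes more inputs than ts.
ts = torch.rand(batch_size, device=device, requires_grad=True)  # (batch_size,)
new_confs = model(ts)                                           # (batch_size, num_atoms, 3)

Then I modified the code to: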

from torch.func import vmap  # (or functorch.vmap on older PyTorch builds)

batch_size = ts.shape[0]
new_confs = new_confs.reshape(batch_size, -1).contiguous()  # (batch_size, num_atoms * 3)
grad_mask = torch.eye(new_confs.shape[-1]).to(new_confs)    # one one-hot row per flattened coordinate
grad_mask = grad_mask.repeat(batch_size, 1, 1)              # (batch_size, D, D) with D = num_atoms * 3

# Vectorize over the D one-hot grad_outputs instead of looping over (i, j).
dx_over_dt = vmap(lambda v: torch.autograd.grad(
    new_confs, ts, create_graph=True,
    retain_graph=True, grad_outputs=v
)[0], in_dims=1, out_dims=1)(grad_mask)

d2x_over_dt2 = vmap(lambda v: torch.autograd.grad(
    dx_over_dt, ts, create_graph=True, grad_outputs=v
)[0], in_dims=1, out_dims=1)(grad_mask)

# Back to (batch_size, num_atoms, 3).
dx_over_dt = dx_over_dt.reshape(batch_size, -1, 3)
d2x_over_dt2 = d2x_over_dt2.reshape(batch_size, -1, 3)
new_confs = new_confs.reshape(batch_size, -1, 3)

However, no matter how I adjust the model structure or the tensor shapes, this code always throws an error on the third call (it runs successfully for the first two batches), with the following error:

I would like to know why this happens. Is this a bug in vmap, or is there an alternative way to write this code?
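
The closest alternative I can think of is to compose jacrev with vmap from torch.func, differentiating per sample. This is only a sketch under the assumption that my model can be evaluated for a single scalar t; the conf_fn wrapper and the model call inside it are placeholders, not my actual code:

from torch.func import vmap, jacrev

def conf_fn(t):
    # Placeholder: evaluate the model for a single scalar time t and
    # return that sample's coordinates with shape (num_atoms, 3).
    return model(t.unsqueeze(0)).squeeze(0)

# Per-sample first and second derivatives w.r.t. t, batched over ts of shape (batch_size,).
dx_over_dt = vmap(jacrev(conf_fn))(ts)            # (batch_size, num_atoms, 3)
d2x_over_dt2 = vmap(jacrev(jacrev(conf_fn)))(ts)  # (batch_size, num_atoms, 3)

I have not verified that this fits my setup, since the real model takes additional inputs besides ts, and I also don't know whether it would avoid whatever is accumulating across batches in my current version.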