How to perform `index_put_` inside a function transformed by `vmap`?

I want to fill in the entries of a lower-triangular matrix inside a function, and then extend that function to a batch of data with `vmap`:

```python
import torch
from torch.func import vmap


def fn(l):
    L = torch.zeros(*l.shape[:-1], 2, 2)
    L[..., [0, 1, 1], [0, 0, 1]] = l
    return L


if __name__ == '__main__':
    L1 = fn(torch.tensor([1., 2., 3.]))
    print(L1)

    Ls1 = fn(torch.tensor([[1., 2., 3.], [4., 5., 6.]]))
    print(Ls1)

    Ls2 = vmap(fn)(torch.tensor([[1., 2., 3.], [4., 5., 6.]]))
    print(Ls2)
```

Called directly, on both a single vector and an explicit batch, it works as expected. Under `vmap`, however, it raises an error:

```
L1 tensor([[1., 0.],
        [2., 3.]])
Ls1 tensor([[[1., 0.],
         [2., 3.]],

        [[4., 0.],
         [5., 6.]]])
Traceback (most recent call last):
  File "/Users/frank/Projects/neural_lagrangian/neural_lagrangian/models/t.py", line 18, in <module>
    Ls2 = vmap(fn)(torch.tensor([[1., 2., 3.], [4., 5., 6.]]))
  File "/Users/frank/opt/anaconda3/envs/neural_lagrangian/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 434, in wrapped
    return _flat_vmap(
  File "/Users/frank/opt/anaconda3/envs/neural_lagrangian/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 39, in fn
    return f(*args, **kwargs)
  File "/Users/frank/opt/anaconda3/envs/neural_lagrangian/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 619, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/Users/frank/Projects/neural_lagrangian/neural_lagrangian/models/t.py", line 7, in fn
    L[..., [0, 1, 1], [0, 0, 1]] = l
RuntimeError: vmap: index_put_(self, *extra_args) is not possible because there exists a Tensor `other` in extra_args that has more elements than `self`. This happened due to `other` being vmapped over but `self` not being vmapped over in a vmap. Please try to use out-of-place operators instead of index_put_. If said operator is being called inside the PyTorch framework, please file a bug report instead.
```
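Reading the error message, the likely cause is this: inside `vmap`, `fn` sees a single sample of shape `(3,)`, so `torch.zeros(...)` creates an ordinary, unbatched `(2, 2)` tensor while `l` is batched. The in-place assignment would then write a batched value into an unbatched tensor, which `vmap` refuses. One possible workaround, sketched below under that reading, is to build the matrix entirely from out-of-place ops (`torch.zeros_like` and `torch.stack`). Because it only indexes the last dimension, it also still works on inputs with explicit leading batch dimensions. This is a minimal sketch for the 2x2 case, not the only possible fix.

```python
import torch
from torch.func import vmap


def fn(l):
    # Build [[l0, 0], [l1, l2]] with out-of-place ops only, so vmap never has
    # to write a batched value into an unbatched tensor.
    zero = torch.zeros_like(l[..., 0])                   # batched under vmap
    row0 = torch.stack([l[..., 0], zero], dim=-1)        # [l0, 0]
    row1 = torch.stack([l[..., 1], l[..., 2]], dim=-1)   # [l1, l2]
    return torch.stack([row0, row1], dim=-2)             # shape (..., 2, 2)


x = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
L1 = fn(torch.tensor([1., 2., 3.]))   # single vector
Ls1 = fn(x)                            # explicit leading batch dim
Ls2 = vmap(fn)(x)                      # vmapped over dim 0
print(torch.equal(Ls1, Ls2))           # True
```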

I need `vmap` because later on I use `jacrev` to compute the Jacobian matrix. I tried the out-of-place `index_put` to work around the error, but `index_put` does not accept an ellipsis (`...`) in its indices.
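Since `jacrev` is the end goal, another option worth sketching treats the assignment as a linear map. The matrix is `l` contracted with a constant 0/1 basis tensor, and `torch.einsum` accepts `...` for any leading dimensions. `ROWS`, `COLS` and `BASIS` are hypothetical names introduced here. The in-place write that fills `BASIS` runs once at module level, outside `vmap`, so it does not trigger the error above. This is a sketch under those assumptions, not a drop-in replacement for every shape.

```python
import torch
from torch.func import jacrev, vmap

# Hypothetical constant basis: BASIS[k] is a 2x2 matrix with a single 1 at the
# position where the k-th entry of l should land.
ROWS = torch.tensor([0, 1, 1])
COLS = torch.tensor([0, 0, 1])
BASIS = torch.zeros(3, 2, 2)
BASIS[torch.arange(3), ROWS, COLS] = 1.0  # runs once, outside vmap


def fn(l):
    # L[..., i, j] = sum_k l[..., k] * BASIS[k, i, j]; einsum is out-of-place
    # and handles arbitrary leading dimensions through "...".
    basis = BASIS.to(dtype=l.dtype, device=l.device)
    return torch.einsum('...k,kij->...ij', l, basis)


x = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
Ls = vmap(fn)(x)           # shape (2, 2, 2)
J = vmap(jacrev(fn))(x)    # per-sample Jacobian dL/dl, shape (2, 2, 2, 3)
print(J.shape)
```

Because the map from `l` to `L` is linear, each per-sample Jacobian is just `BASIS` with its axes permuted to `(i, j, k)`, which makes the result easy to sanity-check.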

Does anyone have a good suggestion? Thank you very much!