I want to fill in a lower triangular matrix inside a function and extend it to a batch of data using vmap:
import torch
from torch.func import vmap


def fn(l):
    L = torch.zeros(*l.shape[:-1], 2, 2)
    L[..., [0, 1, 1], [0, 0, 1]] = l
    return L


if __name__ == '__main__':
    L1 = fn(torch.tensor([1., 2., 3.]))
    print(L1)
    Ls1 = fn(torch.tensor([[1., 2., 3.], [4., 5., 6.]]))
    print(Ls1)
    Ls2 = vmap(fn)(torch.tensor([[1., 2., 3.], [4., 5., 6.]]))
    print(Ls2)
Calling fn directly works as expected, both on a single vector and on a batch. However, calling it through vmap raises an error:
L1:
tensor([[1., 0.],
        [2., 3.]])
Ls1:
tensor([[[1., 0.],
         [2., 3.]],
        [[4., 0.],
         [5., 6.]]])
Traceback (most recent call last):
  File "/Users/frank/Projects/neural_lagrangian/neural_lagrangian/models/t.py", line 18, in <module>
    Ls2 = vmap(fn)(torch.tensor([[1., 2., 3.], [4., 5., 6.]]))
  File "/Users/frank/opt/anaconda3/envs/neural_lagrangian/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 434, in wrapped
    return _flat_vmap(
  File "/Users/frank/opt/anaconda3/envs/neural_lagrangian/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 39, in fn
    return f(*args, **kwargs)
  File "/Users/frank/opt/anaconda3/envs/neural_lagrangian/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 619, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/Users/frank/Projects/neural_lagrangian/neural_lagrangian/models/t.py", line 7, in fn
    L[..., [0, 1, 1], [0, 0, 1]] = l
RuntimeError: vmap: index_put_(self, *extra_args) is not possible because there exists a Tensor `other` in extra_args that has more elements than `self`. This happened due to `other` being vmapped over but `self` not being vmapped over in a vmap. Please try to use out-of-place operators instead of index_put_. If said operator is being called inside the PyTorch framework, please file a bug report instead.
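One possible workaround, sketched under the assumption that only the 2 x 2 case is needed: the traceback points at the in-place write into L, which torch.zeros creates without a batch dimension while l carries one under vmap. Building the matrix out of place with torch.stack and torch.zeros_like avoids that mismatch, because zeros_like inherits the batch dimension of its input. The name fn_stack is hypothetical and only used for this illustration.

```python
import torch
from torch.func import vmap


def fn_stack(l):
    # Hypothetical out-of-place variant of fn for the 2 x 2 case.
    # zeros_like(l[..., 0]) inherits l's batch dimension under vmap,
    # so every stacked piece is batched consistently.
    zero = torch.zeros_like(l[..., 0])
    row0 = torch.stack([l[..., 0], zero], dim=-1)       # [l0, 0 ]
    row1 = torch.stack([l[..., 1], l[..., 2]], dim=-1)  # [l1, l2]
    return torch.stack([row0, row1], dim=-2)            # shape (..., 2, 2)


x = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
out = vmap(fn_stack)(x)
print(out)
```

Since every index uses a leading ellipsis, the same function also works on a plain batch without vmap, so fn_stack(x) gives the same result.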
Because I need jacrev to compute the Jacobian matrix in later steps, vmap is necessary. I tried to solve this with the out-of-place version, index_put, but index_put does not accept an ellipsis in its indices.
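Since index_put only takes a tuple of index tensors, another option is to avoid indexing inside the function altogether: precompute a constant basis tensor from torch.tril_indices and contract it with l. This is a sketch rather than a definitive solution; make_tril_basis, BASIS and fn_basis are hypothetical names, and the basis is built once outside vmap, where the in-place write is harmless. It generalizes to n x n matrices with n(n+1)/2 inputs and composes with jacrev.

```python
import torch
from torch.func import jacrev, vmap


def make_tril_basis(n):
    # Hypothetical helper: basis[k] is an n x n matrix with a single 1 at the
    # k-th lower-triangular position, in the row-major order of tril_indices
    # (for n = 2: (0, 0), (1, 0), (1, 1), matching the indices in fn).
    rows, cols = torch.tril_indices(n, n)
    m = rows.numel()  # n * (n + 1) // 2 entries
    basis = torch.zeros(m, n, n)
    basis[torch.arange(m), rows, cols] = 1.0  # in place, but outside vmap
    return basis


BASIS = make_tril_basis(2)


def fn_basis(l):
    # Out-of-place: L[..., i, j] = sum_k l[..., k] * BASIS[k, i, j]
    return torch.einsum('...k,kij->...ij', l, BASIS)


x = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
print(vmap(fn_basis)(x))
# Jacobian of each 2 x 2 output w.r.t. its 3 inputs: shape (2, 2, 2, 3)
print(vmap(jacrev(fn_basis))(x).shape)
```

Because the map from l to L is linear, each per-sample Jacobian is just the basis with its axes rearranged, BASIS.permute(1, 2, 0).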
Does anyone have a good suggestion? Thank you very much!