Hello everyone, I'm struggling to understand why vmap is not working. Below is a minimal example that reproduces the error:
```python
import torch


def dummy_func(amount):
    test = torch.eye(4, 4)
    test[2, 3] = amount
    return test


if __name__ == "__main__":
    test_input = torch.tensor([1., 2., 3., 4., 5.])
    test_output = torch.vmap(dummy_func)(test_input)
```
Running it produces the following error:
RuntimeError: vmap: inplace arithmetic(self, *extra_args) is not possible because there exists a Tensor `other` in extra_args that has more elements than `self`. This happened due to `other` being vmapped over but `self` not being vmapped over in a vmap. Please try to use out-of-place operators instead of inplace arithmetic. If said operator is being called inside the PyTorch framework, please file a bug report instead.
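As the message says, `test` is created inside the function and is not vmapped, while `amount` is, so the in-place write `test[2, 3] = amount` would have to store one value per batch element into a single 4x4 tensor. A minimal sketch of the out-of-place approach the error suggests, assuming PyTorch 2.x where `torch.vmap` is available: the mask plus `torch.where` construction is just one possible choice, and `dummy_func_inplace` is a hypothetical helper that keeps the original version only to cross-check the result with a plain loop.

```python
import torch


def dummy_func(amount):
    # Non-batched 4x4 identity; it is only read here, never written in place.
    base = torch.eye(4)
    # Boolean mask selecting the single entry to replace. Writing a Python
    # constant into a freshly created (non-vmapped) tensor is fine under vmap.
    mask = torch.zeros(4, 4, dtype=torch.bool)
    mask[2, 3] = True
    # Out-of-place: take `amount` where the mask is True, else the identity entry.
    # Under vmap, `amount` is a 0-d tensor per sample and broadcasts to 4x4.
    return torch.where(mask, amount, base)


def dummy_func_inplace(amount):
    # Hypothetical helper: the original in-place version, fine outside vmap.
    test = torch.eye(4, 4)
    test[2, 3] = amount
    return test


if __name__ == "__main__":
    test_input = torch.tensor([1., 2., 3., 4., 5.])
    test_output = torch.vmap(dummy_func)(test_input)
    print(test_output.shape)  # torch.Size([5, 4, 4])
    # Cross-check against a plain Python loop over the in-place version.
    expected = torch.stack([dummy_func_inplace(a) for a in test_input])
    print(torch.equal(test_output, expected))  # True
```

Since `torch.where` returns a new tensor instead of mutating `base`, vmap can produce a batched output of shape (5, 4, 4) without ever writing a batched value into an unbatched tensor.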
Any help would be much appreciated!