Torch.vmap in-place operation error

Hi all,

I have a problem with torch.vmap.
As stated in the docs, torch.vmap does not support in-place operations where a vmapped tensor is written into a tensor with fewer elements. At first glance my code does not seem to do that, but I am clearly missing something. Additionally, none of the workarounds I found by googling work.
Any help would be very appreciated!

import torch

def faulty_function(x, y):
    mat = torch.zeros(x.shape[0], 3, 3)
    mat[..., 0, 0] = x * y
    return mat

batched_x = torch.randn(10, 1)
batched_y = torch.randn(10, 1)

result = torch.vmap(faulty_function)(batched_x, batched_y)

With this example, I get the following error:

RuntimeError: vmap: inplace arithmetic(self, *extra_args) is not possible because there exists a Tensor other in extra_args that has more elements than self. This happened due to other being vmapped over but self not being vmapped over in a vmap. Please try to use out-of-place operators instead of inplace arithmetic. If said operator is being called inside the PyTorch framework, please file a bug report instead.
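For reference, here is one way I believe the function can be reworked to avoid the in-place write entirely (a sketch, not necessarily the only or best fix): instead of assigning the vmapped value `x * y` into the freshly allocated `mat`, build the matrix out-of-place by broadcasting against a constant mask. The name `fixed_function` below is just a placeholder for the reworked version.

```python
import torch

def fixed_function(x, y):
    # Out-of-place construction: multiply the batched scalar by a
    # constant (non-vmapped) mask instead of writing into a zeros tensor.
    val = (x * y).reshape(-1, 1, 1)  # shape (1, 1, 1) inside vmap
    mask = torch.zeros(3, 3)
    mask[0, 0] = 1.0  # in-place is fine here: mask is not vmapped over
    return val * mask  # broadcasts to (1, 3, 3)

batched_x = torch.randn(10, 1)
batched_y = torch.randn(10, 1)
result = torch.vmap(fixed_function)(batched_x, batched_y)
```

The key difference is that no batched tensor is ever written into an unbatched one; the batched value only participates in out-of-place arithmetic, which vmap can trace.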