Hi! For my current project, I very much need to escape tensors from a vmapped function (i.e. to save them to a "global" variable instead of returning them). There's a paragraph in the documentation saying that this is not supported: UX Limitations — PyTorch 2.9 documentation. However, calling torch.func.debug_unwrap on the tensor before escaping it makes it work:
```python
import torch

dot_products = None

def fn(x, y):
    global dot_products
    dot_products = torch.func.debug_unwrap(x @ y)

torch.vmap(fn, out_dims=None)(torch.randn(32, 10), torch.randn(32, 10))
print(dot_products)  # As expected: shape [32], not a BatchedTensor
```
Note that in this example I could easily have just returned the result instead of escaping it, but in my actual use case, properly returning the result is extremely complex, so I'd rather escape the (unwrapped) tensor.
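For reference, here is what the supported pattern (returning the per-sample result and letting vmap stack it along out_dims) would look like for this toy example — a minimal sketch of the documented API, not of my actual use case:

```python
import torch

# Supported pattern: return the per-sample value; vmap stacks the
# results along the (default) output batch dimension 0.
def fn(x, y):
    return x @ y

dot_products = torch.vmap(fn)(torch.randn(32, 10), torch.randn(32, 10))
print(dot_products.shape)  # torch.Size([32])
```
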
My question is: is there any actual risk or performance issue when doing that?