Using debug_unwrap to escape vmap

Hi! For my current project, I really need to escape tensors from a vmapped function (i.e., save them to a “global” variable instead of returning them). A paragraph in the documentation says this is not supported: UX Limitations — PyTorch 2.9 documentation. However, calling torch.func.debug_unwrap on the tensor before escaping it makes it work:

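```python
import torch

dot_products = None

def fn(x, y):
    global dot_products
    # Inside vmap, x @ y is a BatchedTensor; debug_unwrap exposes the
    # underlying plain tensor so it can be stored outside the transform.
    dot_products = torch.func.debug_unwrap(x @ y)

# fn returns nothing, so out_dims=None
torch.vmap(fn, out_dims=None)(torch.randn(32, 10), torch.randn(32, 10))
print(dot_products)  # As expected: shape [32], not a BatchedTensor
```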

Note that in this example I could easily have just returned the result instead of escaping it, but in my actual use case, returning the result properly is extremely complex, so I’d rather escape the (unwrapped) tensor.

My question is: is there any actual risk or performance cost in doing this?