Save BatchedTensor to a pickle file

I would like to save my BatchedTensor (the tensor type produced by functorch.vmap) to a file for debugging my application. However, a RuntimeError is raised whenever I call torch.save on the BatchedTensor: RuntimeError: Cannot access data pointer of Tensor that doesn't have storage.

How can I save a BatchedTensor?

Could you share a minimal, executable code snippet that produces the BatchedTensor object, as I'm only seeing plain tensors while running the vmap tutorial?

Here is the minimal executable code:

import torch
import functorch

def func(x: torch.Tensor) -> torch.Tensor:
    # some function we want to debug closely
    y = 2 * x
    print(x, y)  # to show x & y are BatchedTensors
    # try to save the tensors (the error occurs here)
    torch.save((x, y), "somefile.pt")
    return y

x = torch.randn((4, 5))
y = functorch.vmap(func)(x)

Thanks for the code snippet.
I don’t know if there is a public and supported way to save the internal and intermediate tensors, but this code using internal calls should work:

def func(x: torch.Tensor) -> torch.Tensor:
    # some function we want to debug closely
    y = 2 * x
    print(x, y)  # to show x & y are BatchedTensors
    # unwrap the BatchedTensors before saving to avoid the error
    vx = torch._C._functorch.get_unwrapped(x)
    vy = torch._C._functorch.get_unwrapped(y)
    torch.save((vx, vy), "somefile.pt")
    return y

x = torch.randn((4, 5))
y = functorch.vmap(func)(x)

Note that these calls can easily break since they are internal and no backwards-compatibility is guaranteed. @richard might know more about this use case.
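Once vmap has returned, the saved file can be loaded back with a plain torch.load; the unwrapped tensors are ordinary tensors that carry the batch dimension as their leading axis. Roughly like this (an untested sketch based on the snippet above):

```python
import torch
import functorch

def func(x: torch.Tensor) -> torch.Tensor:
    y = 2 * x
    # unwrap the BatchedTensors before saving (internal API, may break)
    vx = torch._C._functorch.get_unwrapped(x)
    vy = torch._C._functorch.get_unwrapped(y)
    torch.save((vx, vy), "somefile.pt")
    return y

x = torch.randn((4, 5))
y = functorch.vmap(func)(x)

# after vmap returns, the file holds ordinary tensors; the batch
# dimension is the leading axis of each unwrapped tensor
vx, vy = torch.load("somefile.pt")
print(vx.shape, vy.shape)
```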

Thanks, @ptrblck! The function get_unwrapped is very helpful for debugging inside a vmapped function. It would be useful if there were a public API for this kind of functionality.

Yeah, I’ll let @richard add a comment here in case this is planned as a future debugging tool (or maybe there is already a public API that I’m missing).
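In the meantime, one workaround that sticks to public APIs, at least for cases like the toy example above, is to return the intermediate tensors as extra outputs and save them only after vmap has finished; vmap stacks each output along the batch dimension, so the saved tensors are ordinary ones. A sketch:

```python
import torch
import functorch

def func(x: torch.Tensor) -> torch.Tensor:
    y = 2 * x
    # return the intermediate as an extra output instead of saving inside vmap
    return y, x

# vmap stacks each extra output along the batch dimension
y, x_saved = functorch.vmap(func)(torch.randn((4, 5)))

# now both are plain tensors and can be saved normally
torch.save((x_saved, y), "somefile.pt")
```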