I would like to save my BatchedTensor (the tensor type produced by functorch.vmap) to a file to debug my application. However, calling torch.save on a BatchedTensor raises `RuntimeError: Cannot access data pointer of Tensor that doesn't have storage.`
```python
import torch
import functorch

def func(x: torch.Tensor) -> torch.Tensor:
    # some function where we want to debug closely here
    y = 2 * x
    print(x, y)  # to show x & y are BatchedTensors
    # try to save the tensors (the error occurs here)
    torch.save((x, y), "somefile.pt")
    return y

x = torch.randn((4, 5))
y = functorch.vmap(func)(x)
```
Thanks for the code snippet.
I don’t know if there is a public, supported way to save intermediate tensors inside a vmapped function, but this code using internal calls should work:
```python
import torch
import functorch

def func(x: torch.Tensor) -> torch.Tensor:
    # some function where we want to debug closely here
    y = 2 * x
    print(x, y)  # to show x & y are BatchedTensors
    # unwrap the BatchedTensors into plain tensors (internal API) before saving
    vx = torch._C._functorch.get_unwrapped(x)
    vy = torch._C._functorch.get_unwrapped(y)
    torch.save((vx, vy), "somefile.pt")
    return y

x = torch.randn((4, 5))
y = functorch.vmap(func)(x)
```
Note that these calls can easily break, since they are internal and come with no backward-compatibility guarantee. @richard might know more about this use case.
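A slightly more defensive version of the workaround above, as a sketch: get_unwrapped returns the underlying physical tensor, and its batch dimension is not guaranteed to sit at dim 0 (for example with `in_dims=1`). The hypothetical helper `unwrap_for_save` below therefore checks `is_batchedtensor` first and uses `maybe_get_bdim` to move the batch dimension to the front. This assumes PyTorch 2.x (where vmap is also exposed as `torch.func.vmap`) and that these private `torch._C._functorch` functions keep their current behavior; it only handles a single level of vmap, since nested vmaps wrap a tensor once per level.

```python
import io

import torch
from torch.func import vmap


def unwrap_for_save(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: turn a single-level BatchedTensor into a plain
    tensor whose dim 0 is the vmapped batch dimension.

    Uses private torch._C._functorch functions, which may change or
    disappear between PyTorch releases.
    """
    fc = torch._C._functorch
    if not fc.is_batchedtensor(t):
        return t  # ordinary tensor: nothing to unwrap
    bdim = fc.maybe_get_bdim(t)    # physical position of the batch dim
    inner = fc.get_unwrapped(t)    # underlying tensor, batch dim included
    return inner.movedim(bdim, 0)  # batch dim first, like vmap's outputs


saved = io.BytesIO()  # stands in for "somefile.pt"


def func(x: torch.Tensor) -> torch.Tensor:
    y = 2 * x
    # Saving x or y directly would raise the "doesn't have storage" error.
    torch.save((unwrap_for_save(x), unwrap_for_save(y)), saved)
    return y


x = torch.randn(4, 5)
out = vmap(func)(x)

saved.seek(0)
vx, vy = torch.load(saved)
print(vx.shape, vy.shape)  # torch.Size([4, 5]) torch.Size([4, 5])
```

The saved tensors have the same layout as vmap's own outputs, which makes them easy to compare against the final result.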
Thanks, @ptrblck! The function get_unwrapped is very helpful for debugging inside a vmapped function. It would be useful to have a public API for this.
Yeah, I’ll let @richard add a comment here in case this is planned as a future debugging tool (or in case there is an existing public API that I’m missing).
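On the public-API question: one workaround that relies only on public functionality is to return the intermediates as extra outputs of the vmapped function and save them after vmap returns. At that point they are ordinary tensors with the batch dimension stacked at dim 0, so torch.save works on them. This is a sketch rather than an official debugging tool; `func_with_debug` is a hypothetical variant of `func`, and `torch.func.vmap` assumes PyTorch 2.x (functorch.vmap behaves the same way there).

```python
import io

import torch
from torch.func import vmap


def func_with_debug(x: torch.Tensor):
    # Hypothetical variant of func: same computation, but it also returns the
    # intermediates to inspect instead of saving them inside vmap.
    y = 2 * x
    return y, (x, y)


x = torch.randn(4, 5)
# vmap un-batches every tensor in the returned (nested) tuple, batch dim at 0.
y, (dbg_x, dbg_y) = vmap(func_with_debug)(x)

# Outside vmap these are plain tensors, so torch.save works directly.
buf = io.BytesIO()  # use "somefile.pt" for a real file
torch.save((dbg_x, dbg_y), buf)
```

The trade-off is that the function's return signature changes, so this suits temporary debugging rather than code you ship.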