Hi everyone,
I would like to use vmap together with autograd and an nn.Module. Here is a minimal working example similar to what I would like to implement:
import torch
from torch import nn, vmap

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self._layers = nn.Sequential(nn.Linear(1, 1))

    def forward(self, x):
        return self._layers(x)

net = Model()
batch_size = 5
x = torch.rand(batch_size, 1, requires_grad=True)

def vmap_grad(_x):
    print(_x)
    _y = net(_x)
    return torch.autograd.grad(_y, _x, grad_outputs=torch.ones_like(_x))[0]

y_dx_vmap = vmap(vmap_grad)(x)
When I run this code snippet, I get the following output and error message:
BatchedTensor(lvl=1, bdim=0, value=
    tensor([[0.3807],
            [0.0681],
            [0.3440],
            [0.6405],
            [0.2308]], requires_grad=True)
)
Traceback (most recent call last):
  File "/workspaces/app/test.py", line 22, in <module>
    y_dx_vmap = vmap(vmap_grad)(x)
                ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 436, in wrapped
    return _flat_vmap(
           ^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 39, in fn
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 621, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspaces/app/test.py", line 20, in vmap_grad
    return torch.autograd.grad(_y, _x, grad_outputs=torch.ones_like(_x))[0]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/autograd/__init__.py", line 319, in grad
    result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
Could you please help me and tell me what I am doing wrong? In the torch.func API Reference (PyTorch 2.1 documentation), I read that:
"In general, you can transform over a function that calls a torch.nn.Module."
Also, I verified that the requires_grad attribute of the input tensor is True (see the print statement in vmap_grad() and the terminal output above). Therefore, I do not understand why the error is raised.
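For comparison, here is a minimal sketch of how I read the documentation sentence quoted above: torch.func.grad is applied to a function that calls the module on a single sample and returns a scalar, and vmap then maps that over the batch. The helper name scalar_out and the use of a bare nn.Sequential in place of Model are assumptions made for illustration only, and I am not certain this is the intended usage.

```python
import torch
from torch import nn
from torch.func import grad, vmap

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(1, 1))  # stand-in for Model (assumption)

def scalar_out(_x):
    # Hypothetical helper: _x is one sample of shape (1,).
    # torch.func.grad requires the function to return a scalar.
    return net(_x).sum()

x = torch.rand(5, 1)

# grad differentiates with respect to the first argument;
# vmap applies the resulting function to each row of x.
y_dx = vmap(grad(scalar_out))(x)
print(y_dx.shape)  # torch.Size([5, 1])
```

This variant never calls torch.autograd.grad inside the vmapped function, so it sidesteps the error rather than explaining it.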
System info:
- PyTorch version: 2.1.0.dev20230501+cpu
- Installed via pip
- Python version: 3.11
I would be very grateful if you could help me. Many thanks in advance!