Vmap over autograd.grad of a nn.Module

Hi everyone,

I would like to use vmap with autograd and an nn.Module. Here is a minimal working example, similar to what I would like to implement:

import torch
from torch import nn, vmap

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self._layers = nn.Sequential(nn.Linear(1,1))

    def forward(self, x):
        return self._layers(x)

net = Model()

batch_size = 5
x = torch.rand(batch_size, 1, requires_grad=True)

def vmap_grad(_x):
    print(_x)
    _y = net(_x)
    return torch.autograd.grad(_y, _x, grad_outputs=torch.ones_like(_x))[0]

y_dx_vmap = vmap(vmap_grad)(x)

When I run this code snippet, I get the following output and error message:

BatchedTensor(lvl=1, bdim=0, value=
    tensor([[0.3807],
            [0.0681],
            [0.3440],
            [0.6405],
            [0.2308]], requires_grad=True)
)
Traceback (most recent call last):
  File "/workspaces/app/test.py", line 22, in <module>
    y_dx_vmap = vmap(vmap_grad)(x)
                ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 436, in wrapped
    return _flat_vmap(
           ^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 39, in fn
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 621, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspaces/app/test.py", line 20, in vmap_grad
    return torch.autograd.grad(_y, _x, grad_outputs=torch.ones_like(_x))[0]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torch/autograd/__init__.py", line 319, in grad
    result = Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Could you please help me with this and tell me what I am doing wrong? In the torch.func API Reference (PyTorch 2.1 documentation), I read that

In general, you can transform over a function that calls a torch.nn.Module.

Also, I verified that the requires_grad attribute of the input tensor is True (see the print statement in vmap_grad() and the terminal output), so I cannot understand why this error is raised.

System info:

  • PyTorch version: 2.1.0.dev20230501+cpu
  • Installed via pip
  • Python version: 3.11

I would be very grateful if you could help me. Many thanks in advance!

Hi @david-anton,

You should replace torch.autograd.grad with torch.func.jacrev, and you'll need to use torch.func.functional_call when calling your nn.Module object.

For example,

import torch
from torch import nn
from torch.func import jacrev, vmap, functional_call

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self._layers = nn.Sequential(nn.Linear(1,1))

    def forward(self, x):
        return self._layers(x)

net = Model()

batch_size = 5
x = torch.rand(batch_size, 1, requires_grad=True)


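# torch.func: call the nn.Module functionally, passing its parameters
# explicitly as a dict instead of reading them from the module
def fcall(params, x):
    return functional_call(net, params, x)

params = dict(net.named_parameters())

# differentiate w.r.t. x (argnums=1); vmap over the batch dim of x only, params are shared
y_dx_vmap = vmap(jacrev(fcall, argnums=1), in_dims=(None, 0))(params, x)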

print(x)
print(y_dx_vmap)

print(net._layers[0].weight)  # sanity check: for Linear(1, 1), dy/dx is just the layer's weight

This script prints:

tensor([[0.5876],
        [0.9047],
        [0.2497],
        [0.6386],
        [0.3418]], requires_grad=True)
tensor([[[0.0340]],

        [[0.0340]],

        [[0.0340]],

        [[0.0340]],

        [[0.0340]]], grad_fn=<ViewBackward0>)
Parameter containing:
tensor([[0.0340]], requires_grad=True)
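A slightly shorter variant, assuming you only need derivatives with respect to the input x: torch.func transforms can be applied to a function that calls net directly (which is what the documentation quote in the question refers to), and the module's parameters are then simply treated as constants. functional_call becomes important when you also want to differentiate with respect to the parameters. The sketch below uses torch.func.grad, which needs a scalar output per sample (hence the .squeeze()), alongside jacrev; f_single, f_vec, dy_dx and jac are illustrative names, and net is the same single Linear(1, 1) layer wrapped directly in nn.Sequential for brevity.

```python
import torch
from torch import nn
from torch.func import grad, jacrev, vmap

torch.manual_seed(0)

# Same toy network as in the thread (a single Linear(1, 1) layer),
# wrapped directly in nn.Sequential for brevity (illustrative).
net = nn.Sequential(nn.Linear(1, 1))

batch_size = 5
x = torch.rand(batch_size, 1)

# Per-sample function: one input of shape [1] -> 0-dim output.
# torch.func.grad needs a scalar output, hence .squeeze().
def f_single(xi):
    return net(xi).squeeze()

# Per-sample function with a vector output of shape [1], for jacrev.
def f_vec(xi):
    return net(xi)

# grad/jacrev differentiate w.r.t. xi only; net's parameters are
# captured from the enclosing scope and treated as constants.
dy_dx = vmap(grad(f_single))(x)  # shape [5, 1]
jac = vmap(jacrev(f_vec))(x)     # shape [5, 1, 1], one 1x1 Jacobian per sample

# For a Linear(1, 1) layer, dy/dx equals the layer's weight for every sample.
w = net[0].weight.detach()
print(dy_dx.shape, jac.shape)
print(torch.allclose(dy_dx, w.expand(batch_size, 1)))
print(torch.allclose(jac.squeeze(-1), dy_dx))
```

Both results should agree with the weight of the Linear layer, just like the sanity check in the script above; the grad variant also drops the trailing singleton dimension that jacrev produces.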