RuntimeError: Expected pytree structure of cotangents to be the same as pytree structure of outputs to the function

TL;DR - torch.func.vjp fails with functionalized nn.Module objects due to a mismatched pytree structure.

PyTorch version: 2.1.0.dev20230519
CUDA version: 11.8

Hi All,

I’ve been trying to compute Jacobian-vector and vector-Jacobian products within the torch.func namespace. jvp works, but calling the function returned by torch.func.vjp raises the error below. I know that vjp works for single-tensor functions (see the tutorial "Jacobians, Hessians, hvp, vhp, and more: composing function transforms", PyTorch Tutorials 2.0.1+cu117 documentation).

RuntimeError: Expected pytree structure of cotangents to be the same as pytree structure of outputs to the function. cotangents: {'fc1.weight': *, 'fc1.bias': *, 'fc2.weight': *, 'fc2.bias': *}, primal output: *
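For contrast, here is a minimal sketch of the single-tensor case that works as expected. The function `f` and the input size are illustrative and not part of my actual code. Note that the cotangent passed to `vjp_fn` has the same shape as the output of `f`.

```python
import torch
from torch.func import vjp

# Illustrative single-tensor function (not from the reproducer).
def f(x):
    return x.sin()

x = torch.randn(3)

# vjp returns the primal output and a function that maps a cotangent
# (shaped like the output) to gradients with respect to each primal.
out, vjp_fn = vjp(f, x)

cotangent = torch.ones_like(out)
(grad_x,) = vjp_fn(cotangent)  # one entry per primal input

# For f(x) = sin(x) and a cotangent of ones, the result is cos(x).
print(torch.allclose(grad_x, x.cos()))
```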

A minimal reproducible example is below:

import torch
_ = torch.manual_seed(0)
from torch import nn, Tensor
from torch.func import vjp, jvp, functional_call, vmap
from functools import partial

class NET(nn.Module):

  def __init__(self, num_input: int, num_hidden: int, num_output: int) -> None:
    super(NET, self).__init__()
    self.fc1 = nn.Linear(num_input, num_hidden)
    self.fc2 = nn.Linear(num_hidden, num_output)
    self.act = nn.Tanh()  # activation; the attribute name `act` is assumed

  def forward(self, x: Tensor) -> Tensor:
    x = self.fc1(x)
    x = self.act(x)
    x = self.fc2(x)
    return x


num_input = 4
num_hidden = 64
num_output = 1

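batch_size = 8  # assumed value; any positive integer works here

net = NET(num_input=num_input, num_hidden=num_hidden, num_output=num_output)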

#define functional call here
def fcall(params, x):
  return functional_call(net, params, x)

#functional inputs 
params = dict(net.named_parameters()) 
x = torch.randn(batch_size, num_input)

#generate random vectors for jacobian-vector and vector-jacobian products
tangent_params = {k: torch.randn_like(p) for k, p in params.items()} #random tangent vector
tangent_x = torch.randn_like(x)

#generate functions for jacobian-vector and vector-jacobian products respectively
jvp_fn = lambda params, x, tangent_params, tangent_x: jvp(fcall, (params, x), (tangent_params, tangent_x))
_, vjp_fn = vjp(fcall, params, x)

out, jvp_out = vmap(jvp_fn, in_dims=(None,0,None,0))(params, x, tangent_params, tangent_x) #OK
print("jvp_out: ",jvp_out.shape)
out, vjp_out = vmap(vjp_fn, in_dims=(None,0))(tangent_params, tangent_x) # <<< FAILS HERE
print("vjp_out: ",vjp_out.shape)
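Going by the error message, the function returned by `vjp` expects a cotangent with the same pytree structure as the output of `fcall` (a single tensor), whereas I passed something shaped like the inputs (a dict of parameters plus `x`). Below is a minimal sketch of a call that matches the output structure. It uses an `nn.Sequential` stand-in for the network, and the sizes are illustrative assumptions. It does not use `vmap`, so it is not a per-sample version of the code above.

```python
import torch
from torch import nn
from torch.func import functional_call, vjp

torch.manual_seed(0)

# Hypothetical stand-in for NET; layer sizes and batch size are illustrative.
net = nn.Sequential(nn.Linear(4, 64), nn.Tanh(), nn.Linear(64, 1))
params = dict(net.named_parameters())
x = torch.randn(8, 4)

def fcall(params, x):
    return functional_call(net, params, (x,))

out, vjp_fn = vjp(fcall, params, x)

# The cotangent mirrors the OUTPUT of fcall (one tensor), not its inputs.
cotangent = torch.randn_like(out)

# vjp_fn returns one entry per primal: a dict of parameter gradients and
# a tensor shaped like x.
grad_params, grad_x = vjp_fn(cotangent)

print({k: tuple(v.shape) for k, v in grad_params.items()})
print(tuple(grad_x.shape))
```

Here `grad_params` has the same keys and shapes as `params`, which is the structure I had wrongly been passing in as the cotangent.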