TL;DR - torch.func.vjp fails with functionalized nn.Module objects in the torch.func namespace due to a mismatched pytree structure.
Pytorch version: 2.1.0.dev20230519
CUDA version: 11.8
Hi All,
I’ve been trying to compute Jacobian-vector and vector-Jacobian products within the torch.func namespace; however, when using torch.func.vjp I get the following error. I know that vjp works for single-tensor functions (see here: Jacobians, Hessians, hvp, vhp, and more: composing function transforms — PyTorch Tutorials 2.0.1+cu117 documentation).
RuntimeError: Expected pytree structure of cotangents to be the same as pytree structure of outputs to the function. cotangents: {'fc1.weight': *, 'fc1.bias': *, 'fc2.weight': *, 'fc2.bias': *}, primal output: *
A minimal reproducible example is below:
import torch
_ = torch.manual_seed(0)
from torch import nn, Tensor
from torch.func import vjp, jvp, functional_call, vmap
from functools import partial
class NET(nn.Module):
    def __init__(self, num_input: int, num_hidden: int, num_output: int) -> None:
        super(NET, self).__init__()
        self.fc1 = nn.Linear(num_input, num_hidden)
        self.fc2 = nn.Linear(num_hidden, num_output)
        self.af = nn.Tanh()

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.af(x)
        x = self.fc2(x)
        return x
#==============================================================================#
batch_size=100
num_input = 4
num_hidden = 64
num_output = 1
net = NET(num_input=num_input,
          num_hidden=num_hidden,
          num_output=num_output)
#define functional call here
def fcall(params, x):
    return functional_call(net, params, x)
#functional inputs
params = dict(net.named_parameters())
x = torch.randn(batch_size, num_input)
#generate random vectors for jacobian-vector and vector-jacobian products
tangent_params = {k: torch.randn_like(p) for k, p in params.items()} #random tangent vector
tangent_x = torch.randn_like(x)
#generate functions for jacobian-vector and vector-jacobian products respectively
jvp_fn = lambda params, x, tangent_params, tangent_x: jvp(fcall, (params, x), (tangent_params, tangent_x))
_, vjp_fn = vjp(fcall, params, x)
out, jvp_out = vmap(jvp_fn, in_dims=(None,0,None,0))(params, x, tangent_params, tangent_x) #OK
print("jvp_out: ",jvp_out.shape)
out, vjp_out = vmap(vjp_fn, in_dims=(None,0))(tangent_params, tangent_x) # <<< FAILS HERE
print("vjp_out: ",vjp_out.shape)