Hi, I would like to optimize the model parameters based on a loss computed from the Jacobian of the model output with respect to the input, as in the following code:
import torch
from abc import ABC
from functorch import vmap, jacrev, make_functional
class WaveFunction(torch.nn.Module, ABC):
    def __init__(self):
        super().__init__()
        layers = [torch.nn.Linear(10, 10), torch.nn.Tanh(), torch.nn.Linear(10, 10)]
        self.model = torch.nn.ModuleList(layers)

    def forward(self, z):
        out = z
        for net in self.model:
            out = net(out)
        return out

    def jacobian(self, z):
        # Batched Jacobian of the output with respect to the input z
        functional, params = make_functional(self)
        compute_batch_jacobian = vmap(jacrev(functional, argnums=1),
                                      in_dims=(None, 0), randomness='same')
        batch_jacobian = compute_batch_jacobian(params, z)
        return batch_jacobian
w = WaveFunction()
z = torch.randn((3, 10), requires_grad=True) # Set requires_grad to True for input
# Compute the Jacobian with respect to the input
jacobian = w.jacobian(z)
# Define your loss based on the Jacobian
loss = torch.mean(jacobian)
# Optimize the model parameters based on the loss
optimizer = torch.optim.SGD(w.parameters(), lr=0.01)
# Perform the optimization step
optimizer.zero_grad() # Zero out the gradients to avoid accumulation
loss.backward() # Compute gradients with respect to the model parameters
optimizer.step() # Update the model parameters
# Now, the model parameters have been updated based on the loss computed with respect to the Jacobian
But the code above does not work: all the gradients of the model parameters remain None after the backward call. Is this perhaps because the Jacobian computed with functorch treats the model parameters as constants?
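For reference, here is a small check I used (a minimal sketch, assuming the code above has just been executed). It prints the gradients of the module's own parameters after backward(), and compares the parameters returned by make_functional with the module's parameters; my understanding is that make_functional hands back copies, so any gradients would flow into those copies rather than into w.parameters():

# Minimal check, assuming the snippet above has just run.
# 1) Inspect the module's parameter gradients after loss.backward():
for name, p in w.named_parameters():
    print(name, p.grad)  # in my run, every grad is None

# 2) Compare the parameters returned by make_functional with the module's own;
#    if make_functional returns copies, no parameter object should be shared:
functional, params = make_functional(w)
shared = any(fp is mp for fp, mp in zip(params, w.parameters()))
print(shared)  # expected: False, i.e. the functional params are distinct objects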