Optimizing model parameters with a loss computed from the Jacobian of the model output with respect to the input

Hi, I would like to optimize the model parameters using a loss computed from the Jacobian of the model output with respect to the input, as in the following code:

import torch
from abc import ABC
from functorch import vmap, jacrev, make_functional

class WaveFunction(torch.nn.Module, ABC):
    """Small MLP (Linear -> Tanh -> Linear, width 10) with a batched input-Jacobian helper.

    NOTE(review): this is the failing version from the question. ``jacobian``
    is built on ``functorch.make_functional``, which works on a copy of the
    module. Its returned ``params`` are therefore not the tensors in
    ``self.parameters()``, and gradients of a loss on the Jacobian never reach
    this module's own parameters (their ``.grad`` stays ``None``).
    """

    def __init__(self):
        super().__init__()
        # Layers are applied in list order by forward().
        self.model = torch.nn.ModuleList(
            [
                torch.nn.Linear(10, 10),
                torch.nn.Tanh(),
                torch.nn.Linear(10, 10),
            ]
        )

    def forward(self, z):
        """Feed ``z`` through every layer of ``self.model`` in sequence."""
        hidden = z
        for layer in self.model:
            hidden = layer(hidden)
        return hidden

    def jacobian(self, z):
        """Return the Jacobian of ``forward`` w.r.t. its input for each row of ``z``.

        ``z`` is vmapped along dim 0. For a ``(B, 10)`` batch the result has
        shape ``(B, 10, 10)``.
        """
        fmodel, fparams = make_functional(self)
        # argnums=1 -> differentiate w.r.t. the input, not the parameters.
        per_sample_jac = jacrev(fmodel, argnums=1)
        # Parameters are shared across the batch (None); inputs are split on dim 0.
        batched_jac = vmap(per_sample_jac, in_dims=(None, 0), randomness='same')
        return batched_jac(fparams, z)

w = WaveFunction()
# The input is flagged as requiring grad; the parameter update below does not
# depend on this flag.
z = torch.randn((3, 10), requires_grad=True)

# Plain SGD over the module's own (live) parameters.
optimizer = torch.optim.SGD(w.parameters(), lr=0.01)

# Batched input-Jacobian of the model, reduced to a scalar loss.
jacobian = w.jacobian(z)
loss = jacobian.mean()

# One optimization step: clear stale grads, backpropagate, update.
optimizer.zero_grad()
loss.backward()
optimizer.step()

# NOTE(review): with the make_functional-based jacobian, the text reports that
# w's parameter grads are still None after backward(). SGD skips parameters
# whose grad is None, so this step leaves w unchanged.

But the code above does not work: the gradients of all the model parameters remain None after calling backward(). Is this because the Jacobian computed with functorch treats the model parameters as constants?

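I found a solution. In recent PyTorch versions (2.0 and later), which provide torch.func, the code below works. The key difference is that torch.func.functional_call runs the module with the parameter dictionary passed to it, here the module's own parameters from self.named_parameters(). The Jacobian therefore stays connected to those parameters, and backward() can reach them. By contrast, functorch.make_functional returns parameters that belong to a copy of the module, so the gradients never reached w.parameters():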

import torch
from abc import ABC
from torch.func import vmap, jacrev, functional_call
class WaveFunction(torch.nn.Module, ABC):
    """Small MLP (Linear -> Tanh -> Linear, width 10) whose input-Jacobian stays
    differentiable with respect to the module's own parameters.
    """

    def __init__(self):
        super().__init__()
        self.model = torch.nn.ModuleList(
            [torch.nn.Linear(10, 10), torch.nn.Tanh(), torch.nn.Linear(10, 10)]
        )

    def forward(self, z):
        """Run ``z`` through the layers of ``self.model`` in order."""
        for layer in self.model:
            z = layer(z)
        return z

    def jacobian(self, z):
        """Per-sample Jacobian of ``forward`` w.r.t. its input, batched over dim 0.

        The live parameter tensors from ``self.named_parameters()`` are fed
        into ``fmodel``. The returned Jacobian therefore stays attached to the
        autograd graph of the parameters it depends on, and a loss built from
        it can be backpropagated into them.

        For a ``(B, 10)`` input the result has shape ``(B, 10, 10)``.
        """
        live_params = dict(self.named_parameters())
        # Differentiate w.r.t. argument 1 (the input), not the params dict.
        input_jac = jacrev(self.fmodel, argnums=1)
        # Share the params across the batch (None); split the input along dim 0.
        batched = vmap(input_jac, in_dims=(None, 0), randomness='same')
        return batched(live_params, z)

    def fmodel(self, params, inputs):
        """Call this module on ``inputs``, using ``params`` in place of its registered parameters."""
        return functional_call(self, params, inputs)

model = WaveFunction()
# requires_grad on the input is not needed for the parameter gradients below;
# it is kept as in the original snippet.
z = torch.randn((3, 10), requires_grad=True)
result = model.jacobian(z)  # one input-Jacobian per row of z
loss = torch.mean(result)
loss.backward()

# Show that the Jacobian-based loss actually reached the model parameters.
# The final Linear's bias does not affect the Jacobian, so its grad may stay None.
for name, param in model.named_parameters():
    grad_norm = None if param.grad is None else param.grad.norm().item()
    print(f"{name}: grad norm = {grad_norm}")