Computing a Jacobian column-wise with PyTorch, without a loop

In PyTorch, I have a multi-dimensional solution whose components are stored in the columns of a tensor u. To train the PINN I’m developing, I use this function to compute the time derivative of each column:

def dt(self, u, t):
    # return batch_jacobian(u, t)
    N = u.shape[1]
    u_t = [] # initializing the time derivatives
    # for each column (i), we compute the time derivative ui_t
    for i in range(N):
        ui_t = torch.autograd.grad(u[:, i], t,
                                   grad_outputs=torch.ones_like(u[:, i]),
                                   retain_graph=True,
                                   create_graph=True,
                                   allow_unused=True)[0]
        u_t.append(ui_t) # time derivatives are stored in u_t
    u_t = torch.cat(tuple(u_t), dim=1) # we concatenate all the derivatives
    return u_t

but it involves a for loop that I would like to remove.

While researching this, I found a topic with the same issue, but I wasn’t able to fix my code with the solutions proposed there.

Here is the definition of the functional used to train the neural network:

def functional_f(self, t, theta):
    self.A = torch.tensor(theta[:,1:], requires_grad=False).float().to(self.device)
    self.mu = torch.tensor(theta[:,0], requires_grad=False).float().to(self.device)

    u = self.dnn(t)
    u_t = self.dt(u, t)
    ode = u_t - self.renorm*(self.mu + (self.A @ torch.exp(u).T).T)
    return ode

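The ODE system to be solved is $\partial_t u = \mu + A \exp(u)$, i.e. $\partial_t u_i = \mu_i + \sum_{j=1}^{N} A_{ij} \exp(u_j)$ for $i = 1, \dots, N$, where $\mu \in \mathbb{R}^N$ and $A \in \mathbb{R}^{N \times N}$ are fixed parameters stored in the matrix theta ($\mu$ in its first column, $A$ in the remaining columns). In the code, the right-hand side is additionally multiplied by the scaling factor self.renorm.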
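To make the indexing concrete, here is a small sketch that evaluates the right-hand side for a whole batch at once. It assumes hypothetical random values for theta and u (B time samples, N species) and leaves out self.renorm. It checks that the `(A @ exp(u).T).T` form used in functional_f equals the row-wise product `exp(u) @ A.T`, and that row b matches the per-component formula above.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes: B time samples, N species.
B, N = 5, 3
theta = torch.randn(N, N + 1)  # column 0 holds mu, columns 1: hold A (as in functional_f)
mu = theta[:, 0]               # shape (N,)
A = theta[:, 1:]               # shape (N, N)
u = torch.randn(B, N)          # row b = state u(t_b), column i = species i

# Right-hand side mu + A exp(u), evaluated for every row of u at once.
rhs_question = mu + (A @ torch.exp(u).T).T  # form used in functional_f
rhs_rowwise = mu + torch.exp(u) @ A.T       # same result without the double transpose

# Explicit per-component formula mu_i + sum_j A_ij exp(u_j) for one sample b.
b = 0
rhs_b = torch.stack(
    [mu[i] + sum(A[i, j] * torch.exp(u[b, j]) for j in range(N)) for i in range(N)]
)

print(rhs_question.shape)  # torch.Size([5, 3])
```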

How could I implement the function dt without the loop?
I also tried the jit decorator, but without success either…

(NB: this is a duplicate of my Stack Overflow post, which has not been answered yet…)

Hi @ergias,

The other Discuss thread you shared referenced my answer, which used the functorch library. However, since the release of PyTorch 2.0, functorch has been moved into torch itself, under the torch.func namespace.

For the most part things have stayed the same, except that instead of functionalizing your network directly, you call torch.func.functional_call(net, params, inputs) (where inputs is a single tensor or a tuple of tensors) and pass the parameters of your network as a dictionary.

For example,

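from torch.func import functional_call, jacrev, vmap

net = Net(*args, **kwargs)
params = dict(net.named_parameters())  # {name: tensor} dict of the network's parameters

def fcall(params, x):
  return functional_call(net, params, x)

# in_dims=(None, 0): share params across all samples, map over dim 0 of the batch x
per_sample_jacobian_wrt_params = vmap(jacrev(fcall, argnums=0), in_dims=(None, 0))(params, x)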

To remove your loop, you’ll need a similar trick to the one shown above. If your function has more outputs than inputs, you may want to use forward-mode AD via torch.func.jacfwd for better performance.
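Applied to the question’s setting, a minimal sketch could look like this. It assumes a hypothetical small MLP standing in for self.dnn, mapping t of shape (B, 1) to u of shape (B, N), with no layers that mix samples (such as BatchNorm). It compares the original loop with two loop-free variants: vmap over jacfwd (forward mode, which is cheap here because each sample has a single input t), and a single torch.func.jvp with a tangent of ones, which is valid only because each row of u depends on its own t alone.

```python
import torch
from torch import nn
from torch.func import jacfwd, jvp, vmap

torch.manual_seed(0)

# Hypothetical stand-in for self.dnn: maps t of shape (B, 1) to u of shape (B, N).
N = 3
dnn = nn.Sequential(nn.Linear(1, 16), nn.Tanh(), nn.Linear(16, N))

B = 8
t = torch.linspace(0.0, 1.0, B).unsqueeze(1).requires_grad_(True)  # (B, 1)


def dt_loop(u, t):
    """Reference: the column-by-column loop from the question."""
    cols = []
    for i in range(u.shape[1]):
        ui_t = torch.autograd.grad(u[:, i], t,
                                   grad_outputs=torch.ones_like(u[:, i]),
                                   retain_graph=True, create_graph=True)[0]
        cols.append(ui_t)
    return torch.cat(cols, dim=1)  # (B, N)


def dt_vmap(net, t):
    """Loop-free: per-sample Jacobian du/dt via forward-mode AD.

    vmap strips the batch dimension, so `net` sees a single t of shape (1,)
    and jacfwd returns an (N, 1) Jacobian; vmap stacks them into (B, N, 1).
    """
    return vmap(jacfwd(net))(t).squeeze(-1)  # (B, N)


def dt_jvp(net, t):
    """Loop-free alternative: one Jacobian-vector product with a tangent of ones.

    Only valid if each output row depends on its own input row alone
    (e.g. no BatchNorm), because the tangent perturbs every sample at once.
    """
    _, u_t = jvp(net, (t,), (torch.ones_like(t),))
    return u_t  # (B, N)


u = dnn(t)
u_t_loop = dt_loop(u, t)
u_t_vmap = dt_vmap(dnn, t)
u_t_jvp = dt_jvp(dnn, t)

print(u_t_loop.shape, u_t_vmap.shape, u_t_jvp.shape)  # torch.Size([8, 3]) for all three
```

Note that the loop-free versions take the network itself rather than the precomputed u, because they re-evaluate it inside the transform. Since the module’s parameters are simply closed over, the result stays in the regular autograd graph, so a loss built from u_t can still be backpropagated to the weights; functional_call is only needed when the parameters must be explicit inputs (e.g. for Jacobians with respect to them).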

Also, don’t use the torch.jit decorators (e.g. torch.jit.script); they really won’t do much here. It’s better to use torch.compile() instead, but that will only give minimal improvements as well, since torch.compile was only recently released with PyTorch 2.0.

Furthermore, be careful when re-wrapping tensors, e.g. self.A and self.mu, as that will break any gradient flow to those parameters. You could also initialize them directly on the GPU via torch.as_tensor(data, dtype=torch.float, device=self.device), which may be slightly quicker too.
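A small sketch of the difference, assuming theta is a hypothetical float64 tensor that should receive gradients. (If theta is a NumPy array there is no graph to break, and the benefit of as_tensor is mainly creating the tensor directly with the right dtype and device.)

```python
import warnings

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical theta: column 0 is mu, the remaining columns are A (as in functional_f).
theta = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)

# Re-wrapping with torch.tensor(...) copies the data and detaches it from the graph.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", UserWarning)  # PyTorch warns about copy-constructing from a tensor
    mu_rewrapped = torch.tensor(theta[:, 0]).float().to(device)

# torch.as_tensor(...) converts dtype/device with a differentiable copy
# (or returns the input itself if dtype and device already match).
mu = torch.as_tensor(theta[:, 0], dtype=torch.float, device=device)
A = torch.as_tensor(theta[:, 1:], dtype=torch.float, device=device)

print(mu_rewrapped.requires_grad, mu.requires_grad)  # False True
```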

Hi @AlphaBetaGamma96, thanks for your answer!

However, I was not able to apply your tips to my code.
If I understand correctly:

  • I need to convert the method self.dt into a functional_call that takes the same arguments;
  • then calling vmap on this functional_call will return the gradient of each column, but more efficiently.

Here is what I have done:

def dt(self, u, t):
    nb_species = u.shape[1]
    u_t = [] # initializing the time derivatives
    
    # for each species (i), we compute the time derivative ui_t
    for i in range(nb_species):
        ui_t = torch.autograd.grad(u[:, i], t,
                                    grad_outputs=torch.ones_like(u[:, i]),
                                    retain_graph=True,
                                    create_graph=True,
                                    allow_unused=True)[0]
        u_t.append(ui_t) # time derivatives are stored in u_t
    u_t = torch.cat(tuple(u_t), dim=1) # we concatenate all the derivatives
    return u_t

def dt_call(self, u, t):
    return torch.func.functional_call(self.dt, u, t)

def functional_f(self, t, theta):
    self.A = torch.tensor(theta[:,1:], requires_grad=False).float().to(self.device)
    self.mu = torch.tensor(theta[:,0], requires_grad=False).float().to(self.device)

    u = self.dnn(t)
    u_t = torch.func.vmap(torch.func.jacrev(self.dt_call, argnums=(0)), in_dims=(None, 0))(u, t)
    # u_t = self.dt(u, t)
    ode = u_t - self.renorm*(self.mu + (self.A @ torch.exp(u).T).T)
    return ode

I did not really understand what the in_dims argument of vmap does, and I think it is the key point, because I get this error: ValueError: Expected parameter_and_buffer_dicts to be a dict, or a list/tuple of dicts, but got <class 'torch.Tensor'>. (I also tried changing the values of the argument, but that raises a similar error: ValueError: vmap: Expected all tensors to have the same size in the mapped dimension, got sizes [3, 21] for the mapped dimension.)

Hi @ergias,

In my example, functional_call is applied to nn.Module objects; in your case you just have a function, so you can pass it directly to torch.func.vmap. I’d recommend reading the docs for torch.func.grad and torch.func.vmap so you can apply them to your particular use case. Without a minimal reproducible example, it’s hard to see how everything fits together.
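For reference, in_dims has one entry per positional argument of the mapped function: None means the argument is shared by every call, and an integer gives the dimension to map over. A minimal sketch with a hypothetical plain function apply_A (no nn.Module, no functional_call); the last part reproduces the kind of size-mismatch error quoted above, which appears when an argument that should be shared is mapped instead.

```python
import torch
from torch.func import vmap

torch.manual_seed(0)


def apply_A(A, x):
    """Hypothetical plain function: matrix-vector product for ONE vector x of shape (3,)."""
    return A @ x


A = torch.randn(3, 3)   # one matrix, shared by every sample
xs = torch.randn(5, 3)  # a batch of 5 vectors stacked along dimension 0

# in_dims=(None, 0): do not map over A, map over dimension 0 of xs.
out = vmap(apply_A, in_dims=(None, 0))(A, xs)
print(out.shape)  # torch.Size([5, 3])

# in_dims=(0, 0) would also slice A along its dimension 0 (size 3) while xs
# has size 5 there, so vmap raises a size-mismatch ValueError.
try:
    vmap(apply_A, in_dims=(0, 0))(A, xs)
except ValueError as err:
    print(err)
```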

If you can share a minimal reproducible example, I can have a look.