How to vmap the calculation of a Jacobian matrix?

Hi,

I would like to speed up my code that calculates the Jacobian matrix of a neural network's output w.r.t. the network parameters. So far, my code looks as follows:

def autograd(input, params):
    # Backpropagate `input` (one scalar per call) to all parameters in `params`.
    O = torch.autograd.grad(input, params, torch.ones_like(input), allow_unused=True, retain_graph=True, create_graph=False)
    return O

def _compute_centered_jacobian(model, samples):
    """Computes O = d Psi / d Theta."""
    # Get the trainable NN parameters.
    parameters = tuple(filter(lambda p: p.requires_grad, model.parameters()))
    # Calculate log_probs and phases: I want the Jacobian of both.
    log_probs, phases = model.log_probabilities(samples)
    # One backward pass per sample: this is the slow part.
    jac_ampl = [torch.cat([j_.flatten() for j_ in autograd(0.5 * log_probs[i], parameters) if j_ is not None])
                for i in range(log_probs.size(0))]
    jac_phase = [torch.cat([j_.flatten() for j_ in autograd(phases[i], parameters) if j_ is not None])
                 for i in range(phases.size(0))]
    jac_ampl = torch.stack(jac_ampl)
    jac_phase = torch.stack(jac_phase)
    return jac_ampl, jac_phase, log_probs, phases

This works, but it is very slow. Furthermore, I cannot use torch.autograd.functional.jacobian because I don't have an explicit function of the network weights, i.e. nothing like model.log_probabilities(network weights). Now I found out that it is possible to speed up such calculations using vmap and tried something like this:

def autograd_vmap(inputs, params):
    def autograd(input):
        # Backpropagate one entry of `inputs` to all parameters.
        O = torch.autograd.grad(input, params, torch.ones_like(input), allow_unused=True, retain_graph=True, create_graph=False)
        return O
    # Try to vectorize the per-sample backward passes over the batch dimension.
    out = vmap(autograd)(inputs)
    return out

However, this raises the following error:

    O = torch.autograd.grad(input, params, torch.ones_like(input), allow_unused=True, retain_graph=True, create_graph=False)
  File "...", line ..., in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

I would be happy about any help!

Hi @hl465,

Have a look at the torch.func package, specifically torch.func.jacrev, which automatically vectorizes the Jacobian calculation!
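
For intuition, here is a minimal sketch of what jacrev does for a plain function (f and x are made-up names, not tied to your model):

import torch

def f(x):
    # toy function: R^5 -> R^5
    return torch.sin(x) * x.sum()

x = torch.randn(5)
jac = torch.func.jacrev(f)(x)  # Jacobian of shape (5, 5): d f_i / d x_j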

Hi,
thanks a lot for your answer. I think torch.func.jacrev has the same problem as torch.autograd.functional.jacobian, i.e. I would need a function that takes the network parameters as an input (which is not the case for me)…

You can pair torch.func.jacrev with torch.func.functional_call.

An example below,

net = Model(*args)
params = dict(net.named_parameters())

def func(params, x):
  return torch.func.functional_call(net, params, x)

# single sample
jacobian = torch.func.jacrev(func, argnums=0)(params, x)

# multiple samples: vmap over the batch dimension of x, sharing params across the batch
jacobian = torch.func.vmap(torch.func.jacrev(func, argnums=0), in_dims=(None, 0))(params, x)
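
Since params is a dict, the result is a dict of per-parameter Jacobians keyed by parameter name. If you want a single (n_samples, n_params) matrix like in your original code, you can flatten and concatenate the entries, roughly like this (a sketch, assuming the batched call above and one scalar output per sample):

n_samples = x.shape[0]
# each entry of `jacobian` then has shape (n_samples, *param_shape)
jac_matrix = torch.cat([j.reshape(n_samples, -1) for j in jacobian.values()], dim=-1)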


Thanks so much for your help!! I think I'm starting to understand. So is the "func" used in functional_call my model.log_probabilities, and x = samples?

It should be,

def func(params, x):
  return torch.func.functional_call(net, params, x)

I re-wrote func by mistake!
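
One thing to keep in mind: torch.func.functional_call always invokes the module's forward. If log_probabilities is a separate method rather than forward, one option is a thin wrapper module whose forward calls it, for example (just a sketch; LogProbWrapper is a made-up name):

import torch
from torch import nn

class LogProbWrapper(nn.Module):
    # Hypothetical wrapper: exposes model.log_probabilities through forward,
    # since torch.func.functional_call always calls forward.
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        return self.model.log_probabilities(x)

wrapped = LogProbWrapper(net)
params = dict(wrapped.named_parameters())

def func(params, x):
    return torch.func.functional_call(wrapped, params, (x,))

# Note: under vmap each sample is passed without its batch dimension,
# so log_probabilities must be able to handle a single unbatched sample.
jac = torch.func.vmap(torch.func.jacrev(func, argnums=0), in_dims=(None, 0))(params, samples)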

Thanks :slight_smile: Now I am getting a new error:

RuntimeError: vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. If error is occurring somewhere inside PyTorch internals, please file a bug report.

but I never call .item()

Perhaps your function is calling .item() under the hood (see the sketch below for a few common implicit cases). Can you share a minimal reproducible script?
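
For reference, these are a few operations that implicitly call .item() on a 0-dim tensor (a small sketch):

import torch

t = torch.tensor(3.0)
float(t)      # Tensor.__float__ goes through .item()
int(t)        # Tensor.__int__ goes through .item()
f"{t:.3f}"    # formatting a 0-dim tensor with a format spec also calls .item()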

Hi, sorry for the delay. Here is an example that runs into the error. I have replaced the vmap(jacrev(...)) call with torch.autograd.functional.jacobian(..., vectorize=True), which does essentially the same thing.

import os
import time

import numpy as np
import torch
from torch import Tensor
from torch import nn
from model import Model

from functorch import jacrev, vmap, make_functional, grad


def compute_centered_jacobian(model,samples):
    # Computes real and imaginary parts of O=d Psi/d Theta. 
    func, parameters = make_functional(model)
    def func_(*params):
        return func(params, samples)
    #jacrev(func)(parameters, samples[0])
    #jac = vmap(jacrev(func), in_dims=(None,0))(parameters, samples)
    #print("end------")
    print(parameters[0][0])
    jac = torch.autograd.functional.jacobian(func_, parameters, vectorize=True)
    jac_ampl = jac[0]
    print(jac_ampl[0][0])
    print("----------")
    jac_phase = jac[1]
    print(jac_phase[0][0])
    jac_ampl = torch.cat([it.reshape(it.size(0),-1) for it in list(jac_ampl)], axis=-1)
    jac_phase = torch.cat([it.reshape(it.size(0),-1) for it in list(jac_phase)], axis=-1)
    with torch.no_grad():
        jac_ampl -= jac_ampl.mean(axis=0)
        jac_phase -= jac_phase.mean(axis=0)
    return jac_ampl, jac_phase


def _compute_gradient_with_curvature(Tinv, E, O):
    n_samples = Tinv.size(0)
    #TinvE = solve_linear_problem(T, E, 1e-4) #/n_samples
    TinvE = torch.mv(Tinv, E) #/n_samples
    δ = torch.einsum("ij,j", O.t(),TinvE) #/n_samples
    return δ

def compute_gradient_with_curvature(Ore, Oim, E, model,**kwargs):
    # The following is an implementation of Eq. (6.22) (without the minus sign) in "Quantum Monte
    # Carlo Approaches for Correlated Systems" by F.Becca & S.Sorella.
    #
    # Basically, what we need to compute is `2/N · Re[E*·(O - ⟨O⟩)]`, where `E` is a `Nx1` vector
    # of complex-numbered local energies, `O` is a `NxM` matrix of logarithmic derivatives, `⟨O⟩` is
    # a `1xM` vector of mean logarithmic derivatives, and `N` is the number of Monte Carlo samples.
    #
    # Now, `Re[a*·b] = Re[a]·Re[b] - Im[a*]·Im[b] = Re[a]·Re[b] +
    # Im[a]·Im[b]` that's why there are no conjugates or minus signs in
    # the code.
    E = (E-E.mean()) 
    T = torch.einsum("ij,jk", Ore, Ore.t())+torch.einsum("ij,jk", Oim, Oim.t())
    Tinv = torch.linalg.pinv(T,rtol=1e-12)
    δ = _compute_gradient_with_curvature(Tinv, E.real, Ore)+_compute_gradient_with_curvature(Tinv, E.imag, Oim)
    return δ

@torch.no_grad()
def apply_grads(model,grad):
    """ Assigns the calculated gradients to the model parameter gradients. """
    i = 0
    for p in filter(lambda x: x.requires_grad, model.parameters()):
        n = p.numel()
        if p.grad is not None:
            p.grad.copy_(grad[i : i + n].view(p.size()))
        else:
            print("gradient = None. Please check whats going wrong!")
            p.grad = grad[i : i + n].view(p.size())
        i += 1

def run_sr(model, E, samples, optimizer, scheduler=None):
    """ Runs a minSR step. """
    print(E.mean())
    # Compute the real and imaginary parts of the Jacobian matrix.
    Ore, Oim = compute_centered_jacobian(model, samples)
    # Compute the gradients of the NN parameters from that.
    grads = compute_gradient_with_curvature(Ore, Oim, E, model)
    print(grads)
    # Assign the calculated gradients and make one optimization step.
    apply_grads(model,grads)
    optimizer.step()
    if scheduler is not None:
        scheduler.step()

In this case I get:

RuntimeError: Batching rule not implemented for aten::item. We could not generate a fallback.

Can you share the full stacktrace?